From 69f2ca8cfe04cfeafeffcfd282459819a9c58b62 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 25 Sep 2025 13:11:07 +0800 Subject: [PATCH 01/32] add rotary kernel support to Qwen3 model Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 6 +++ .../models/qwen3/modeling_qwen3.py | 40 ++++++++++++++++++- .../models/qwen3/modular_qwen3.py | 1 + 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 5be21e2f9a51..05f580771888 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -115,6 +115,9 @@ register_kernel_mapping(_KERNEL_MAPPING) + # Preload the rotary kernel as it's used in many models. + rotary_kernel = get_kernel(repo_id="kernels-community/rotary") + except ImportError: _kernels_available = False @@ -138,6 +141,8 @@ def replace_kernel_forward_from_hub(*args, **kwargs): def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") + rotary_kernel = None + def is_kernel(attn_implementation: Optional[str]) -> bool: """Check whether `attn_implementation` matches a kernel pattern from the hub.""" @@ -201,4 +206,5 @@ def load_and_register_kernel(attn_implementation: str) -> None: "use_kernel_forward_from_hub", "register_kernel_mapping", "replace_kernel_forward_from_hub", + "rotary_kernel", ] diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 81b16c4ee6b6..d06c76c93a44 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -28,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import rotary_kernel from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -117,6 +118,34 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Rotary kernel implementation wrapper + Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_rotated = q.clone() + k_rotated = k.clone() + + # Get half dimension for rotation + half_dim = q.shape[-1] // 2 + q1 = q_rotated[..., :half_dim] + q2 = q_rotated[..., half_dim:] + k1 = k_rotated[..., :half_dim] + k2 = k_rotated[..., half_dim:] + if cos.shape[-1] != half_dim: + # Trim cos/sin to match half_dim + cos = cos[..., :half_dim] + sin = sin[..., :half_dim] + + # Apply rotary embedding using our kernel + rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return q_rotated, k_rotated + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, @@ -192,6 +221,7 @@ def forward( attention_mask: Optional[torch.Tensor], past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, + use_kernels: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] @@ -202,7 +232,10 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if use_kernels and rotary_kernel: + query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -252,6 +285,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + use_kernels: Optional[bool] = False, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states @@ -265,6 +299,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + use_kernels=use_kernels, **kwargs, ) hidden_states = residual + hidden_states @@ -362,6 +397,7 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + use_kernels: Optional[bool] = False, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): @@ -415,6 +451,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + use_kernels=use_kernels, **kwargs, ) @@ -485,6 +522,7 @@ def forward( inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, + use_kernels=self.use_kernels, **kwargs, ) diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index f1e38841faf4..ef764b0e90b0 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -19,6 +19,7 @@ import torch from ...cache_utils import Cache +from ...integrations.hub_kernels import rotary_kernel from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS From d2bf5c56be2d7104d04f99297b90c110deb68b30 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 25 Sep 2025 15:05:14 +0800 Subject: [PATCH 02/32] delete unnecessary import Signed-off-by: Liu, Kaixuan --- src/transformers/models/qwen3/modular_qwen3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index ef764b0e90b0..f1e38841faf4 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -19,7 +19,6 @@ import torch from ...cache_utils import Cache -from ...integrations.hub_kernels import rotary_kernel from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS From 
b0cbab5640650cef7328cf592fabb717aaf0cbca Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 25 Sep 2025 08:19:18 +0000 Subject: [PATCH 03/32] adjust code Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 6 ------ src/transformers/modeling_utils.py | 8 +++++++- src/transformers/models/qwen3/modeling_qwen3.py | 11 ++--------- 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 05f580771888..5be21e2f9a51 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -115,9 +115,6 @@ register_kernel_mapping(_KERNEL_MAPPING) - # Preload the rotary kernel as it's used in many models. - rotary_kernel = get_kernel(repo_id="kernels-community/rotary") - except ImportError: _kernels_available = False @@ -141,8 +138,6 @@ def replace_kernel_forward_from_hub(*args, **kwargs): def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") - rotary_kernel = None - def is_kernel(attn_implementation: Optional[str]) -> bool: """Check whether `attn_implementation` matches a kernel pattern from the hub.""" @@ -206,5 +201,4 @@ def load_and_register_kernel(attn_implementation: str) -> None: "use_kernel_forward_from_hub", "register_kernel_mapping", "replace_kernel_forward_from_hub", - "rotary_kernel", ] diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a64085c4e931..b3e55f8abcee 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -173,6 +173,9 @@ _is_quantized = False _is_ds_init_called = False +# Initialize rotary_kernel as None, will be set when kernelize() is called +rotary_kernel = None + def is_local_dist_rank_0(): return ( @@ -5788,10 +5791,13 @@ def kernelize(self): raise ValueError( "Kernels are not available. To use kernels, please install kernels using `pip install kernels`" ) - from kernels import Device, Mode, kernelize + from kernels import Device, Mode, get_kernel, kernelize mode = Mode.INFERENCE if not self.training else Mode.TRAINING kernelize(self, device=Device(type=self.device.type), mode=mode) + # Preload the rotary kernel as it's used in many models. 
+ global rotary_kernel + rotary_kernel = get_kernel(repo_id="kernels-community/rotary") self._use_kernels = True @property diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index d06c76c93a44..76cf1fd7e758 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -28,7 +28,6 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...integrations.hub_kernels import rotary_kernel from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -39,7 +38,7 @@ ) from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, rotary_kernel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.deprecation import deprecate_kwarg @@ -221,7 +220,6 @@ def forward( attention_mask: Optional[torch.Tensor], past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - use_kernels: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] @@ -232,7 +230,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - if use_kernels and rotary_kernel: + if rotary_kernel: query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) else: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -285,7 +283,6 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - use_kernels: Optional[bool] = False, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states @@ -299,7 +296,6 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - use_kernels=use_kernels, **kwargs, ) hidden_states = residual + hidden_states @@ -397,7 +393,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - use_kernels: Optional[bool] = False, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): @@ -451,7 +446,6 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - use_kernels=use_kernels, **kwargs, ) @@ -522,7 +516,6 @@ def forward( inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, - use_kernels=self.use_kernels, **kwargs, ) From 8dede65dba715890970e4ced0334bb9296b5da52 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 25 Sep 2025 10:04:48 +0000 Subject: [PATCH 04/32] adjust code Signed-off-by: Liu, Kaixuan --- src/transformers/models/qwen3/modeling_qwen3.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git 
a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index d06c76c93a44..3b6a26aa9101 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -47,6 +47,10 @@ from .configuration_qwen3 import Qwen3Config +# Global variable to track kernel usage, set by model instances +use_kernels = False + + @use_kernel_forward_from_hub("RMSNorm") class Qwen3RMSNorm(nn.Module): def __init__(self, hidden_size, eps: float = 1e-6) -> None: @@ -221,7 +225,6 @@ def forward( attention_mask: Optional[torch.Tensor], past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - use_kernels: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] @@ -285,7 +288,6 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - use_kernels: Optional[bool] = False, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states @@ -299,7 +301,6 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - use_kernels=use_kernels, **kwargs, ) hidden_states = residual + hidden_states @@ -397,7 +398,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - use_kernels: Optional[bool] = False, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): @@ -451,7 +451,6 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - use_kernels=use_kernels, **kwargs, ) @@ -514,6 +513,10 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" + # Set global use_kernels flag based on model's kernel usage + global use_kernels + use_kernels = getattr(self, "use_kernels", False) + outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -522,7 +525,6 @@ def forward( inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, - use_kernels=self.use_kernels, **kwargs, ) From 137069b0f3ccf6fb27dad74f8c394af981b5fcea Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 25 Sep 2025 10:12:02 +0000 Subject: [PATCH 05/32] put get rotary kernel to hub_kernels.py Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 5 +++++ src/transformers/modeling_utils.py | 8 +------- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 5be21e2f9a51..e5f89589442e 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -114,6 +114,8 @@ } register_kernel_mapping(_KERNEL_MAPPING) + # Preload the rotary kernel as it's used in many models. 
+ rotary_kernel = get_kernel(repo_id="kernels-community/rotary") except ImportError: _kernels_available = False @@ -138,6 +140,8 @@ def replace_kernel_forward_from_hub(*args, **kwargs): def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") + rotary_kernel = None + def is_kernel(attn_implementation: Optional[str]) -> bool: """Check whether `attn_implementation` matches a kernel pattern from the hub.""" @@ -201,4 +205,5 @@ def load_and_register_kernel(attn_implementation: str) -> None: "use_kernel_forward_from_hub", "register_kernel_mapping", "replace_kernel_forward_from_hub", + "rotary_kernel", ] diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b3e55f8abcee..a64085c4e931 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -173,9 +173,6 @@ _is_quantized = False _is_ds_init_called = False -# Initialize rotary_kernel as None, will be set when kernelize() is called -rotary_kernel = None - def is_local_dist_rank_0(): return ( @@ -5791,13 +5788,10 @@ def kernelize(self): raise ValueError( "Kernels are not available. To use kernels, please install kernels using `pip install kernels`" ) - from kernels import Device, Mode, get_kernel, kernelize + from kernels import Device, Mode, kernelize mode = Mode.INFERENCE if not self.training else Mode.TRAINING kernelize(self, device=Device(type=self.device.type), mode=mode) - # Preload the rotary kernel as it's used in many models. - global rotary_kernel - rotary_kernel = get_kernel(repo_id="kernels-community/rotary") self._use_kernels = True @property From 8ac3e1ea1e6125096582d73159eb40cb88c0fe1d Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 25 Sep 2025 10:19:59 +0000 Subject: [PATCH 06/32] fix wrong import Signed-off-by: Liu, Kaixuan --- src/transformers/models/qwen3/modeling_qwen3.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index a5eeb3ab90c5..b01c0f5f9815 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -28,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import rotary_kernel from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -38,7 +39,7 @@ ) from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, rotary_kernel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.deprecation import deprecate_kwarg From 29f83f232829f282199911efb591f8797b633198 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 26 Sep 2025 02:51:28 +0000 Subject: [PATCH 07/32] refine code and adjust related modular code Signed-off-by: Liu, Kaixuan --- .../models/qwen3/modeling_qwen3.py | 77 +++++++++-------- .../models/qwen3/modular_qwen3.py | 86 +++++++++++++++++-- 2 files changed, 120 insertions(+), 43 deletions(-) 
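A minimal, kernels-free sketch of the identity that the `apply_rotary_kernel` wrapper (introduced in PATCH 01/32 and reworked in PATCH 07/32 below) relies on: the wrapper splits q/k into half-dimension views, trims cos/sin to half width, and hands them to the hub kernel's in-place `apply_rotary(x1, x2, cos, sin, out1, out2, conjugate)`; this is expected to reproduce the stock `apply_rotary_pos_emb`, whose full-width cos/sin are just the half-width tables concatenated twice. Whether the hub kernel's call semantics match exactly is an assumption about `kernels-community/rotary`; all names below are illustrative and not part of the patches.

    import torch

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def rope_reference(q, cos, sin):
        # Stock Transformers path: full-width cos/sin (half-width tables repeated twice).
        return (q * cos) + (rotate_half(q) * sin)

    def rope_half_rotation(q, cos_half, sin_half):
        # What the kernel wrapper effectively computes: rotate the two halves
        # of the head dimension with half-width cos/sin, writing the result in place.
        q = q.clone()
        half = q.shape[-1] // 2
        q1, q2 = q[..., :half], q[..., half:]
        new_q1 = q1 * cos_half - q2 * sin_half
        new_q2 = q2 * cos_half + q1 * sin_half
        q[..., :half], q[..., half:] = new_q1, new_q2
        return q

    seq, head_dim = 5, 8
    q = torch.randn(1, 2, seq, head_dim)
    freqs = torch.rand(seq, head_dim // 2)
    cos_half, sin_half = freqs.cos(), freqs.sin()
    cos = torch.cat([cos_half, cos_half], dim=-1)  # how Transformers builds full-width tables
    sin = torch.cat([sin_half, sin_half], dim=-1)

    torch.testing.assert_close(rope_reference(q, cos, sin), rope_half_rotation(q, cos_half, sin_half))
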
diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index b01c0f5f9815..8f34c0ebc574 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -28,7 +28,6 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...integrations.hub_kernels import rotary_kernel from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -47,10 +46,6 @@ from .configuration_qwen3 import Qwen3Config -# Global variable to track kernel usage, set by model instances -use_kernels = False - - @use_kernel_forward_from_hub("RMSNorm") class Qwen3RMSNorm(nn.Module): def __init__(self, hidden_size, eps: float = 1e-6) -> None: @@ -122,34 +117,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """ - Rotary kernel implementation wrapper - Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - q_rotated = q.clone() - k_rotated = k.clone() - - # Get half dimension for rotation - half_dim = q.shape[-1] // 2 - q1 = q_rotated[..., :half_dim] - q2 = q_rotated[..., half_dim:] - k1 = k_rotated[..., :half_dim] - k2 = k_rotated[..., half_dim:] - if cos.shape[-1] != half_dim: - # Trim cos/sin to match half_dim - cos = cos[..., :half_dim] - sin = sin[..., :half_dim] - - # Apply rotary embedding using our kernel - rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) - rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) - return q_rotated, k_rotated - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, @@ -188,6 +155,36 @@ def eager_attention_forward( return attn_output, attn_weights +def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Rotary kernel implementation wrapper + Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature + """ + from ...integrations.hub_kernels import rotary_kernel + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_rotated = q.clone() + k_rotated = k.clone() + + # Get half dimension for rotation + half_dim = q.shape[-1] // 2 + q1 = q_rotated[..., :half_dim] + q2 = q_rotated[..., half_dim:] + k1 = k_rotated[..., :half_dim] + k2 = k_rotated[..., half_dim:] + if cos.shape[-1] != half_dim: + # Trim cos/sin to match half_dim + cos = cos[..., :half_dim] + sin = sin[..., :half_dim] + + # Apply rotary embedding using our kernel + rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return q_rotated, k_rotated + + class Qwen3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -235,8 +232,14 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - if rotary_kernel: - query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) + # Check if use_kernels is passed in kwargs + use_kernels = kwargs.get("use_kernels", False) + if use_kernels: + try: + query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) + except (ImportError, AttributeError, RuntimeError): + # Fallback to regular rotary position embedding if kernel is not available + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) else: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -513,10 +516,7 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - # Set global use_kernels flag based on model's kernel usage - global use_kernels use_kernels = getattr(self, "use_kernels", False) - outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -525,6 +525,7 @@ def forward( inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, + use_kernels=use_kernels, **kwargs, ) diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index f1e38841faf4..65babb1e3eb6 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -14,13 +14,13 @@ # limitations under the License. 
"""PyTorch Qwen3 model.""" -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging @@ -49,6 +49,36 @@ _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B" +def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Rotary kernel implementation wrapper + Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature + """ + from ...integrations.hub_kernels import rotary_kernel + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_rotated = q.clone() + k_rotated = k.clone() + + # Get half dimension for rotation + half_dim = q.shape[-1] // 2 + q1 = q_rotated[..., :half_dim] + q2 = q_rotated[..., half_dim:] + k1 = k_rotated[..., :half_dim] + k2 = k_rotated[..., half_dim:] + if cos.shape[-1] != half_dim: + # Trim cos/sin to match half_dim + cos = cos[..., :half_dim] + sin = sin[..., :half_dim] + + # Apply rotary embedding using our kernel + rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return q_rotated, k_rotated + + class Qwen3RMSNorm(Qwen2RMSNorm): pass @@ -82,7 +112,16 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Check if use_kernels is passed in kwargs + use_kernels = kwargs.get("use_kernels", False) + if use_kernels: + try: + query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) + except (ImportError, AttributeError, RuntimeError): + # Fallback to regular rotary position embedding if kernel is not available + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -125,7 +164,16 @@ class Qwen3Model(Qwen2Model): class Qwen3ForCausalLM(Qwen2ForCausalLM): def forward( self, - **super_kwargs: Unpack[TransformersKwargs], + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -149,7 +197,35 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" - return super().forward(**super_kwargs) + use_kernels = getattr(self, "use_kernels", False) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + use_kernels=use_kernels, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) class Qwen3ForSequenceClassification(Qwen2ForSequenceClassification): From 94e4f60846ad87829202ad2d827ac984464ae06a Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 26 Sep 2025 07:58:26 +0000 Subject: [PATCH 08/32] fix modular mismatch bug Signed-off-by: Liu, Kaixuan --- .../models/dots1/modeling_dots1.py | 43 ++++++++++++- .../models/qwen3_moe/modeling_qwen3_moe.py | 41 +++++++++++- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 63 ++++++++++++++++++- 3 files changed, 142 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index ea500c064512..4f419e1f1cbb 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -170,6 +170,36 @@ def eager_attention_forward( return attn_output, attn_weights +def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Rotary kernel implementation wrapper + Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature + """ + from ...integrations.hub_kernels import rotary_kernel + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_rotated = q.clone() + k_rotated = k.clone() + + # Get half dimension for rotation + half_dim = q.shape[-1] // 2 + q1 = q_rotated[..., :half_dim] + q2 = q_rotated[..., half_dim:] + k1 = k_rotated[..., :half_dim] + k2 = k_rotated[..., half_dim:] + if cos.shape[-1] != half_dim: + # Trim cos/sin to match half_dim + cos = cos[..., :half_dim] + sin = sin[..., :half_dim] + + # Apply rotary embedding using our kernel + rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return q_rotated, k_rotated + + class Dots1Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -217,7 +247,16 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Check if use_kernels is passed in kwargs + use_kernels = kwargs.get("use_kernels", False) + if use_kernels: + try: + query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) + except (ImportError, AttributeError, RuntimeError): + # Fallback to regular rotary position embedding if kernel is not available + query_states, key_states = 
apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -580,6 +619,7 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" + use_kernels = getattr(self, "use_kernels", False) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -588,6 +628,7 @@ def forward( inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, + use_kernels=use_kernels, **kwargs, ) diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 2056e7c76a3a..174b2854e8af 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -119,6 +119,36 @@ def eager_attention_forward( return attn_output, attn_weights +def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Rotary kernel implementation wrapper + Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature + """ + from ...integrations.hub_kernels import rotary_kernel + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_rotated = q.clone() + k_rotated = k.clone() + + # Get half dimension for rotation + half_dim = q.shape[-1] // 2 + q1 = q_rotated[..., :half_dim] + q2 = q_rotated[..., half_dim:] + k1 = k_rotated[..., :half_dim] + k2 = k_rotated[..., half_dim:] + if cos.shape[-1] != half_dim: + # Trim cos/sin to match half_dim + cos = cos[..., :half_dim] + sin = sin[..., :half_dim] + + # Apply rotary embedding using our kernel + rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return q_rotated, k_rotated + + class Qwen3MoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -166,7 +196,16 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Check if use_kernels is passed in kwargs + use_kernels = kwargs.get("use_kernels", False) + if use_kernels: + try: + query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) + except (ImportError, AttributeError, RuntimeError): + # Fallback to regular rotary position embedding if kernel is not available + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 1172ebf90919..8a1e0b786a27 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1397,6 +1397,36 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +def 
apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Rotary kernel implementation wrapper + Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature + """ + from ...integrations.hub_kernels import rotary_kernel + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_rotated = q.clone() + k_rotated = k.clone() + + # Get half dimension for rotation + half_dim = q.shape[-1] // 2 + q1 = q_rotated[..., :half_dim] + q2 = q_rotated[..., half_dim:] + k1 = k_rotated[..., :half_dim] + k2 = k_rotated[..., half_dim:] + if cos.shape[-1] != half_dim: + # Trim cos/sin to match half_dim + cos = cos[..., :half_dim] + sin = sin[..., :half_dim] + + # Apply rotary embedding using our kernel + rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return q_rotated, k_rotated + + class Qwen3OmniMoeThinkerTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1448,7 +1478,16 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Check if use_kernels is passed in kwargs + use_kernels = kwargs.get("use_kernels", False) + if use_kernels: + try: + query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) + except (ImportError, AttributeError, RuntimeError): + # Fallback to regular rotary position embedding if kernel is not available + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -2324,7 +2363,16 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Check if use_kernels is passed in kwargs + use_kernels = kwargs.get("use_kernels", False) + if use_kernels: + try: + query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) + except (ImportError, AttributeError, RuntimeError): + # Fallback to regular rotary position embedding if kernel is not available + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -3375,7 +3423,16 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Check if use_kernels is passed in kwargs + use_kernels = kwargs.get("use_kernels", False) + if use_kernels: + try: + query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) + except (ImportError, AttributeError, RuntimeError): + # Fallback to regular rotary position embedding if kernel is not available + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, 
cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache From af67a74027039942983b4cd80340e4bcafe0239a Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Tue, 21 Oct 2025 06:30:00 +0000 Subject: [PATCH 09/32] update code, use lazy load kernels Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 1 + .../models/dots1/modeling_dots1.py | 16 +++--- .../models/qwen3/modeling_qwen3.py | 46 +++-------------- .../models/qwen3/modular_qwen3.py | 50 +++---------------- .../models/qwen3_moe/modeling_qwen3_moe.py | 16 +++--- src/transformers/utils/import_utils.py | 5 ++ 6 files changed, 35 insertions(+), 99 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 84a478adff99..2c9dfd3eb25a 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -168,6 +168,7 @@ def register_kernel_mapping(*args, **kwargs): _HUB_KERNEL_MAPPING: dict[str, str] = { "causal-conv1d": "kernels-community/causal-conv1d", + "rotary_emb": "kernels-community/rotary", } _KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {} diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 2818cdfd2073..a390fcf495b0 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -259,6 +259,11 @@ def __init__(self, config: Dots1Config, layer_idx: int): self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + # Load and cache the rotary kernel once during initialization to improve performance + from ...integrations.hub_kernels import lazy_load_kernel + + self._rotary_kernel_loaded = lazy_load_kernel("rotary_emb") + def forward( self, hidden_states: torch.Tensor, @@ -276,14 +281,9 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - # Check if use_kernels is passed in kwargs - use_kernels = kwargs.get("use_kernels", False) - if use_kernels: - try: - query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) - except (ImportError, AttributeError, RuntimeError): - # Fallback to regular rotary position embedding if kernel is not available - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Use the cached kernel loaded during initialization + if self._rotary_kernel_loaded is not None: + query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin) else: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 7b3e2d5fb045..967c99df5d44 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -29,6 +29,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ 
-220,36 +221,6 @@ def eager_attention_forward( return attn_output, attn_weights -def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """ - Rotary kernel implementation wrapper - Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature - """ - from ...integrations.hub_kernels import rotary_kernel - - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - q_rotated = q.clone() - k_rotated = k.clone() - - # Get half dimension for rotation - half_dim = q.shape[-1] // 2 - q1 = q_rotated[..., :half_dim] - q2 = q_rotated[..., half_dim:] - k1 = k_rotated[..., :half_dim] - k2 = k_rotated[..., half_dim:] - if cos.shape[-1] != half_dim: - # Trim cos/sin to match half_dim - cos = cos[..., :half_dim] - sin = sin[..., :half_dim] - - # Apply rotary embedding using our kernel - rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) - rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) - return q_rotated, k_rotated - - class Qwen3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -280,6 +251,10 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + # Load and cache the rotary kernel once during initialization to improve performance + rotary_kernel = lazy_load_kernel("rotary_emb") + self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + def forward( self, hidden_states: torch.Tensor, @@ -297,16 +272,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - # Check if use_kernels is passed in kwargs - use_kernels = kwargs.get("use_kernels", False) - if use_kernels: - try: - query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) - except (ImportError, AttributeError, RuntimeError): - # Fallback to regular rotary position embedding if kernel is not available - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - else: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index 15ee7c89af29..925708409d2b 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -14,12 +14,13 @@ # limitations under the License. 
"""PyTorch Qwen3 model.""" -from typing import Callable, Optional, Union from collections.abc import Callable +from typing import Optional, Union import torch from ...cache_utils import Cache +from ...integrations.hub_kernels import lazy_load_kernel from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -47,36 +48,6 @@ _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B" -def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """ - Rotary kernel implementation wrapper - Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature - """ - from ...integrations.hub_kernels import rotary_kernel - - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - q_rotated = q.clone() - k_rotated = k.clone() - - # Get half dimension for rotation - half_dim = q.shape[-1] // 2 - q1 = q_rotated[..., :half_dim] - q2 = q_rotated[..., half_dim:] - k1 = k_rotated[..., :half_dim] - k2 = k_rotated[..., half_dim:] - if cos.shape[-1] != half_dim: - # Trim cos/sin to match half_dim - cos = cos[..., :half_dim] - sin = sin[..., :half_dim] - - # Apply rotary embedding using our kernel - rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) - rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) - return q_rotated, k_rotated - - class Qwen3RMSNorm(Qwen2RMSNorm): pass @@ -97,6 +68,10 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + # Load and cache the rotary kernel once during initialization to improve performance + rotary_kernel = lazy_load_kernel("rotary_emb") + self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + def forward( self, hidden_states: torch.Tensor, @@ -114,16 +89,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - # Check if use_kernels is passed in kwargs - use_kernels = kwargs.get("use_kernels", False) - if use_kernels: - try: - query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) - except (ImportError, AttributeError, RuntimeError): - # Fallback to regular rotary position embedding if kernel is not available - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - else: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -187,7 +153,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" - use_kernels = getattr(self, "use_kernels", False) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -196,7 +161,6 @@ def forward( inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, - use_kernels=use_kernels, **kwargs, ) diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 55d5bd353f21..948c6d22d62c 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -178,6 +178,11 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = getattr(config, "sliding_window", None) + # Load and cache the rotary kernel once during initialization to improve performance + from ...integrations.hub_kernels import lazy_load_kernel + + self._rotary_kernel_loaded = lazy_load_kernel("rotary_emb") + def forward( self, hidden_states: torch.Tensor, @@ -195,14 +200,9 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - # Check if use_kernels is passed in kwargs - use_kernels = kwargs.get("use_kernels", False) - if use_kernels: - try: - query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) - except (ImportError, AttributeError, RuntimeError): - # Fallback to regular rotary position embedding if kernel is not available - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Use the cached kernel loaded during initialization + if self._rotary_kernel_loaded is not None: + query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin) else: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index e2b50ae7054a..3d92fe6211f1 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -657,6 +657,11 @@ def is_flash_linear_attention_available(): return is_torch_cuda_available() and is_available and version.parse(fla_version) >= version.parse("0.2.2") +@lru_cache +def is_rotary_emb_available() -> bool: + return is_torch_xpu_available() and _is_package_available("rotary_emb") + + @lru_cache def is_causal_conv1d_available() -> bool: return is_torch_cuda_available() and _is_package_available("causal_conv1d") From 28c69d3f5c4c4a35a53be67386c913ddcca9da3d Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Tue, 21 Oct 2025 15:25:54 +0000 Subject: [PATCH 10/32] fix check modular conversion issue Signed-off-by: Liu, Kaixuan --- .../models/dots1/modeling_dots1.py | 41 +--------- .../models/qwen3/modeling_qwen3.py | 5 +- .../models/qwen3/modular_qwen3.py | 3 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 39 +-------- .../models/qwen3_next/modeling_qwen3_next.py | 26 +++--- .../models/qwen3_next/modular_qwen3_next.py | 4 +- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 81 +++++-------------- .../models/qwen3_vl/modeling_qwen3_vl.py | 8 +- .../models/qwen3_vl/modular_qwen3_vl.py | 3 +- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 8 +- 10 files changed, 57 insertions(+), 161 deletions(-) diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index a390fcf495b0..80318c781757 
100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -199,36 +199,6 @@ def eager_attention_forward( return attn_output, attn_weights -def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """ - Rotary kernel implementation wrapper - Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature - """ - from ...integrations.hub_kernels import rotary_kernel - - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - q_rotated = q.clone() - k_rotated = k.clone() - - # Get half dimension for rotation - half_dim = q.shape[-1] // 2 - q1 = q_rotated[..., :half_dim] - q2 = q_rotated[..., half_dim:] - k1 = k_rotated[..., :half_dim] - k2 = k_rotated[..., half_dim:] - if cos.shape[-1] != half_dim: - # Trim cos/sin to match half_dim - cos = cos[..., :half_dim] - sin = sin[..., :half_dim] - - # Apply rotary embedding using our kernel - rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) - rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) - return q_rotated, k_rotated - - class Dots1Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -262,7 +232,8 @@ def __init__(self, config: Dots1Config, layer_idx: int): # Load and cache the rotary kernel once during initialization to improve performance from ...integrations.hub_kernels import lazy_load_kernel - self._rotary_kernel_loaded = lazy_load_kernel("rotary_emb") + rotary_kernel = lazy_load_kernel("rotary_emb") + self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb def forward( self, @@ -281,11 +252,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - # Use the cached kernel loaded during initialization - if self._rotary_kernel_loaded is not None: - query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin) - else: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -648,7 +615,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" - use_kernels = getattr(self, "use_kernels", False) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -657,7 +623,6 @@ def forward( inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, - use_kernels=use_kernels, **kwargs, ) diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 967c99df5d44..81cdf8240b85 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -29,7 +29,6 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -252,6 +251,8 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None # Load and cache the rotary kernel once during initialization to improve performance + from ...integrations.hub_kernels import lazy_load_kernel + rotary_kernel = lazy_load_kernel("rotary_emb") self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb @@ -508,7 +509,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - use_kernels = getattr(self, "use_kernels", False) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -517,7 +517,6 @@ def forward( inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, - use_kernels=use_kernels, **kwargs, ) diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index 925708409d2b..2ee6a704e611 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -20,7 +20,6 @@ import torch from ...cache_utils import Cache -from ...integrations.hub_kernels import lazy_load_kernel from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -69,6 +68,8 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None # Load and cache the rotary kernel once during initialization to improve performance + from ...integrations.hub_kernels import lazy_load_kernel + rotary_kernel = lazy_load_kernel("rotary_emb") self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 948c6d22d62c..59a8a6cb8c83 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -119,36 +119,6 @@ def eager_attention_forward( return attn_output, attn_weights -def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """ - Rotary kernel implementation wrapper - Adapts rotary kernels 
implementation to match HuggingFace apply_rotary_pos_emb signature - """ - from ...integrations.hub_kernels import rotary_kernel - - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - q_rotated = q.clone() - k_rotated = k.clone() - - # Get half dimension for rotation - half_dim = q.shape[-1] // 2 - q1 = q_rotated[..., :half_dim] - q2 = q_rotated[..., half_dim:] - k1 = k_rotated[..., :half_dim] - k2 = k_rotated[..., half_dim:] - if cos.shape[-1] != half_dim: - # Trim cos/sin to match half_dim - cos = cos[..., :half_dim] - sin = sin[..., :half_dim] - - # Apply rotary embedding using our kernel - rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) - rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) - return q_rotated, k_rotated - - class Qwen3MoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -181,7 +151,8 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): # Load and cache the rotary kernel once during initialization to improve performance from ...integrations.hub_kernels import lazy_load_kernel - self._rotary_kernel_loaded = lazy_load_kernel("rotary_emb") + rotary_kernel = lazy_load_kernel("rotary_emb") + self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb def forward( self, @@ -200,11 +171,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - # Use the cached kernel loaded during initialization - if self._rotary_kernel_loaded is not None: - query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin) - else: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 036d2a30a55e..82ec886ff040 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -267,12 +267,9 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. - Removes the interleaving of cos and sin from GLM - Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. 
@@ -292,19 +289,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) - - # Keep half or full tensor for later concatenation - rotary_dim = cos.shape[-1] - q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] - k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - - # Apply rotary embeddings on the first half or full tensor - q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) - k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) - - # Concatenate back to full shape - q_embed = torch.cat([q_embed, q_pass], dim=-1) - k_embed = torch.cat([k_embed, k_pass], dim=-1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -375,6 +361,12 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape + # Load and cache the rotary kernel once during initialization to improve performance + from ...integrations.hub_kernels import lazy_load_kernel + + rotary_kernel = lazy_load_kernel("rotary_emb") + self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + def forward( self, hidden_states: torch.Tensor, @@ -397,7 +389,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 1e9962797775..f566fef0684b 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -34,7 +34,7 @@ is_causal_conv1d_available, is_flash_linear_attention_available, ) -from ..bamba.modeling_bamba import apply_mask_to_padding_states, apply_rotary_pos_emb +from ..bamba.modeling_bamba import apply_mask_to_padding_states from ..gemma2.modeling_gemma2 import Gemma2RotaryEmbedding from ..gemma3.modeling_gemma3 import Gemma3RMSNorm from ..llama.modeling_llama import ( @@ -249,7 +249,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 334fec06fa2c..aaa89114bbd5 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1430,36 +1430,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """ - Rotary kernel implementation wrapper - Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature - """ - from 
...integrations.hub_kernels import rotary_kernel - - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - q_rotated = q.clone() - k_rotated = k.clone() - - # Get half dimension for rotation - half_dim = q.shape[-1] // 2 - q1 = q_rotated[..., :half_dim] - q2 = q_rotated[..., half_dim:] - k1 = k_rotated[..., :half_dim] - k2 = k_rotated[..., half_dim:] - if cos.shape[-1] != half_dim: - # Trim cos/sin to match half_dim - cos = cos[..., :half_dim] - sin = sin[..., :half_dim] - - # Apply rotary embedding using our kernel - rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) - rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) - return q_rotated, k_rotated - - class Qwen3OmniMoeThinkerTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1493,6 +1463,12 @@ def __init__(self, config, layer_idx): ) # thus post q_norm does not need reshape self.sliding_window = None + # Load and cache the rotary kernel once during initialization to improve performance + from ...integrations.hub_kernels import lazy_load_kernel + + rotary_kernel = lazy_load_kernel("rotary_emb") + self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + def forward( self, hidden_states: torch.Tensor, @@ -1510,16 +1486,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - # Check if use_kernels is passed in kwargs - use_kernels = kwargs.get("use_kernels", False) - if use_kernels: - try: - query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) - except (ImportError, AttributeError, RuntimeError): - # Fallback to regular rotary position embedding if kernel is not available - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - else: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -2348,6 +2315,12 @@ def __init__(self, config: Qwen3OmniMoeConfig, layer_idx: int): ) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + # Load and cache the rotary kernel once during initialization to improve performance + from ...integrations.hub_kernels import lazy_load_kernel + + rotary_kernel = lazy_load_kernel("rotary_emb") + self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + def forward( self, hidden_states: torch.Tensor, @@ -2365,16 +2338,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - # Check if use_kernels is passed in kwargs - use_kernels = kwargs.get("use_kernels", False) - if use_kernels: - try: - query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) - except (ImportError, AttributeError, RuntimeError): - # Fallback to regular rotary position embedding if kernel is not available - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - else: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin 
and cos are specific to RoPE models; cache_position needed for the static cache @@ -3431,6 +3395,12 @@ def __init__(self, config: Qwen3OmniMoeCode2WavConfig, layer_idx): self.k_norm = nn.Identity() self.sliding_window = config.sliding_window + # Load and cache the rotary kernel once during initialization to improve performance + from ...integrations.hub_kernels import lazy_load_kernel + + rotary_kernel = lazy_load_kernel("rotary_emb") + self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + def forward( self, hidden_states: torch.Tensor, @@ -3448,16 +3418,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - # Check if use_kernels is passed in kwargs - use_kernels = kwargs.get("use_kernels", False) - if use_kernels: - try: - query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) - except (ImportError, AttributeError, RuntimeError): - # Fallback to regular rotary position embedding if kernel is not available - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - else: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index d41cfa4b090e..03a3e2d2e7f4 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -443,6 +443,12 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape + # Load and cache the rotary kernel once during initialization to improve performance + from ...integrations.hub_kernels import lazy_load_kernel + + rotary_kernel = lazy_load_kernel("rotary_emb") + self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + def forward( self, hidden_states: torch.Tensor, @@ -460,7 +466,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 5d1c88d03bc4..8a9a1304f35e 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -57,7 +57,6 @@ Qwen3Attention, Qwen3DecoderLayer, Qwen3Model, - apply_rotary_pos_emb, eager_attention_forward, ) @@ -423,7 +422,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git 
a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 264902c2d8a4..1a9f9feec330 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -257,6 +257,12 @@ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape + # Load and cache the rotary kernel once during initialization to improve performance + from ...integrations.hub_kernels import lazy_load_kernel + + rotary_kernel = lazy_load_kernel("rotary_emb") + self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + def forward( self, hidden_states: torch.Tensor, @@ -274,7 +280,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache From 7ac16a1719072e9b829bc5f634b0bba52d88622b Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 22 Oct 2025 06:07:21 +0000 Subject: [PATCH 11/32] fix CI bug for qwen3-next Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 1 - .../models/qwen3_next/modeling_qwen3_next.py | 23 ++++++++++++++++--- .../models/qwen3_next/modular_qwen3_next.py | 6 ++++- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 872a5c310a45..f9ae5bbd9f59 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -166,7 +166,6 @@ def register_kernel_mapping(*args, **kwargs): rotary_kernel = None - _HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = { "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"}, "rotary_emb": {"repo_id": "kernels-community/rotary"}, diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 82ec886ff040..6c4b5872a17c 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -267,9 +267,12 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. + Removes the interleaving of cos and sin from GLM + Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. 
@@ -289,8 +292,19 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) return q_embed, k_embed @@ -365,7 +379,10 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + + # qwen3_next uses partial rotary embeddings, which the rotary kernel doesn't support yet + # So we use the bamba apply_rotary_pos_emb function (imported at top) directly + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index f566fef0684b..a49db3ffb3c7 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -34,7 +34,7 @@ is_causal_conv1d_available, is_flash_linear_attention_available, ) -from ..bamba.modeling_bamba import apply_mask_to_padding_states +from ..bamba.modeling_bamba import apply_mask_to_padding_states, apply_rotary_pos_emb from ..gemma2.modeling_gemma2 import Gemma2RotaryEmbedding from ..gemma3.modeling_gemma3 import Gemma3RMSNorm from ..llama.modeling_llama import ( @@ -227,6 +227,10 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): ) del self.sliding_window + # qwen3_next uses partial rotary embeddings, which the rotary kernel doesn't support yet + # So we use the bamba apply_rotary_pos_emb function (imported at top) directly + self.rotary_fn = apply_rotary_pos_emb + def forward( self, hidden_states: torch.Tensor, From adce121a2ea7a6e08f70ed67b7a487c8ad3eb717 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 22 Oct 2025 06:46:55 +0000 Subject: [PATCH 12/32] fix CI issue Signed-off-by: Liu, Kaixuan --- .../models/qwen3_next/modeling_qwen3_next.py | 13 ++++------- .../models/qwen3_next/modular_qwen3_next.py | 23 +++++++++++++++++-- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 6c4b5872a17c..38afef9518ca 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -350,6 +350,7 @@ class Qwen3NextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Qwen3NextConfig, layer_idx: int): + # Initialize nn.Module (skip Qwen3MoeAttention.__init__ to avoid loading rotary kernel) super().__init__() self.config = config self.layer_idx = layer_idx @@ -358,6 +359,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): self.scaling = self.head_dim**-0.5 self.attention_dropout = 
config.attention_dropout self.is_causal = True + self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias ) @@ -370,15 +372,8 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! - self.k_norm = Qwen3NextRMSNorm( - self.head_dim, eps=config.rms_norm_eps - ) # thus post q_norm does not need reshape - - # Load and cache the rotary kernel once during initialization to improve performance - from ...integrations.hub_kernels import lazy_load_kernel - - rotary_kernel = lazy_load_kernel("rotary_emb") + self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # qwen3_next uses partial rotary embeddings, which the rotary kernel doesn't support yet # So we use the bamba apply_rotary_pos_emb function (imported at top) directly diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index a49db3ffb3c7..b16c9a044ac4 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -221,11 +221,30 @@ class Qwen3NextRMSNorm(Gemma3RMSNorm): class Qwen3NextAttention(Qwen3MoeAttention): def __init__(self, config: Qwen3NextConfig, layer_idx: int): - super().__init__(config, layer_idx) + # Initialize nn.Module (skip Qwen3MoeAttention.__init__ to avoid loading rotary kernel) + nn.Module.__init__(self) + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias ) - del self.sliding_window + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # qwen3_next uses partial rotary embeddings, which the rotary kernel doesn't support yet # So we use the bamba apply_rotary_pos_emb function (imported at top) directly From b4757f4565f5c5c50c4f2b271278c86767b4e77b Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 22 Oct 2025 07:19:36 +0000 Subject: [PATCH 13/32] delete unused code Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 5 ----- src/transformers/models/dots1/modeling_dots1.py | 1 - src/transformers/models/qwen3/modeling_qwen3.py | 1 - src/transformers/models/qwen3/modular_qwen3.py | 1 - src/transformers/models/qwen3_moe/modeling_qwen3_moe.py | 1 - .../models/qwen3_omni_moe/modeling_qwen3_omni_moe.py | 3 --- src/transformers/models/qwen3_vl/modeling_qwen3_vl.py | 1 - 
.../models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 1 - 8 files changed, 14 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index f9ae5bbd9f59..28b3659cee6f 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -137,8 +137,6 @@ } register_kernel_mapping(_KERNEL_MAPPING) - # Preload the rotary kernel as it's used in many models. - rotary_kernel = get_kernel(repo_id="kernels-community/rotary") except ImportError: _kernels_available = False @@ -163,8 +161,6 @@ def replace_kernel_forward_from_hub(*args, **kwargs): def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") - rotary_kernel = None - _HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = { "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"}, @@ -286,6 +282,5 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] "use_kernel_forward_from_hub", "register_kernel_mapping", "replace_kernel_forward_from_hub", - "rotary_kernel", "lazy_load_kernel", ] diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 80318c781757..e99da6a77356 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -229,7 +229,6 @@ def __init__(self, config: Dots1Config, layer_idx: int): self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - # Load and cache the rotary kernel once during initialization to improve performance from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 81cdf8240b85..14c8d0fcb610 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -250,7 +250,6 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - # Load and cache the rotary kernel once during initialization to improve performance from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index 2ee6a704e611..a8a237a09fd0 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -67,7 +67,6 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - # Load and cache the rotary kernel once during initialization to improve performance from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 59a8a6cb8c83..cabd435bbea7 100644 --- 
a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -148,7 +148,6 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = getattr(config, "sliding_window", None) - # Load and cache the rotary kernel once during initialization to improve performance from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 6e76e316de08..0a989a18c8e4 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1463,7 +1463,6 @@ def __init__(self, config, layer_idx): ) # thus post q_norm does not need reshape self.sliding_window = None - # Load and cache the rotary kernel once during initialization to improve performance from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") @@ -2322,7 +2321,6 @@ def __init__(self, config: Qwen3OmniMoeConfig, layer_idx: int): ) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - # Load and cache the rotary kernel once during initialization to improve performance from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") @@ -3402,7 +3400,6 @@ def __init__(self, config: Qwen3OmniMoeCode2WavConfig, layer_idx): self.k_norm = nn.Identity() self.sliding_window = config.sliding_window - # Load and cache the rotary kernel once during initialization to improve performance from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 03a3e2d2e7f4..e917f956b1bb 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -443,7 +443,6 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape - # Load and cache the rotary kernel once during initialization to improve performance from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 1a9f9feec330..c2f0cdb84059 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -257,7 +257,6 @@ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape - # Load and cache the rotary kernel once during initialization to improve performance from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") From 3bbfd64693d0897453c4bd5d0c6a13357639b330 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 23 Oct 2025 01:34:01 +0000 Subject: [PATCH 14/32] rename to `apply_rotary_transformers` Signed-off-by: Liu, Kaixuan --- 
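Note: the kernel entry point is renamed so it can act as a drop-in replacement for the eager
apply_rotary_pos_emb (same `(q, k, cos, sin)` call sites in the attention modules below).
A minimal sketch of the selection pattern this series repeats in each attention `__init__`;
the helper name `resolve_rotary_fn` is illustrative only, while `lazy_load_kernel`,
`"rotary_emb"` and `apply_rotary_transformers` are the names used in these diffs:

    from transformers.integrations.hub_kernels import lazy_load_kernel

    def resolve_rotary_fn(fallback):
        # lazy_load_kernel returns None when `kernels` or the hub repo is unavailable
        kernel = lazy_load_kernel("rotary_emb")
        fn = getattr(kernel, "apply_rotary_transformers", None) if kernel is not None else None
        # fall back to the eager apply_rotary_pos_emb when the kernel entry point is missing
        return fn if fn is not None else fallback

    # usage inside an attention module (sketch):
    #   self.rotary_fn = resolve_rotary_fn(apply_rotary_pos_emb)          # in __init__
    #   query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin)  # in forward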
src/transformers/models/dots1/modeling_dots1.py | 2 +- src/transformers/models/qwen3/modeling_qwen3.py | 2 +- src/transformers/models/qwen3/modular_qwen3.py | 2 +- src/transformers/models/qwen3_moe/modeling_qwen3_moe.py | 2 +- .../models/qwen3_omni_moe/modeling_qwen3_omni_moe.py | 6 +++--- src/transformers/models/qwen3_vl/modeling_qwen3_vl.py | 2 +- .../models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 2 +- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index e99da6a77356..ac6d254b12ed 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -232,7 +232,7 @@ def __init__(self, config: Dots1Config, layer_idx: int): from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 14c8d0fcb610..ff1fe9847af2 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -253,7 +253,7 @@ def __init__(self, config: Qwen3Config, layer_idx: int): from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index a8a237a09fd0..a4453bc16408 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -70,7 +70,7 @@ def __init__(self, config: Qwen3Config, layer_idx: int): from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index cabd435bbea7..a8b9099b0440 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -151,7 +151,7 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 0a989a18c8e4..1f60108fcdc7 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1466,7 
+1466,7 @@ def __init__(self, config, layer_idx): from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb def forward( self, @@ -2324,7 +2324,7 @@ def __init__(self, config: Qwen3OmniMoeConfig, layer_idx: int): from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb def forward( self, @@ -3403,7 +3403,7 @@ def __init__(self, config: Qwen3OmniMoeCode2WavConfig, layer_idx): from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index e917f956b1bb..ace51914f9e4 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -446,7 +446,7 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index c2f0cdb84059..a4289a7f1ca4 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -260,7 +260,7 @@ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): from ...integrations.hub_kernels import lazy_load_kernel rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_kernel if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb def forward( self, From caa549a62810ad090040cacff3073956c301b5e9 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 29 Oct 2025 07:28:25 +0000 Subject: [PATCH 15/32] adjust import `lazy_load_kernel` location Signed-off-by: Liu, Kaixuan --- src/transformers/models/qwen3/modeling_qwen3.py | 3 +-- src/transformers/models/qwen3/modular_qwen3.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index ff1fe9847af2..cb8f5c32394b 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -29,6 +29,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from 
...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -250,8 +251,6 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - from ...integrations.hub_kernels import lazy_load_kernel - rotary_kernel = lazy_load_kernel("rotary_emb") self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index a4453bc16408..2c32f600ba86 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -20,6 +20,7 @@ import torch from ...cache_utils import Cache +from ...integrations.hub_kernels import lazy_load_kernel from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -67,8 +68,6 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - from ...integrations.hub_kernels import lazy_load_kernel - rotary_kernel = lazy_load_kernel("rotary_emb") self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb From 08a99598c521c618d04905de0ccdcb6f62ed6250 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 29 Oct 2025 07:45:41 +0000 Subject: [PATCH 16/32] Update modular-generated modeling files with lazy_load_kernel import location Signed-off-by: Liu, Kaixuan --- src/transformers/models/dots1/modeling_dots1.py | 3 +-- src/transformers/models/qwen3_moe/modeling_qwen3_moe.py | 3 +-- src/transformers/models/qwen3_vl/modeling_qwen3_vl.py | 3 +-- src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 3 +-- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index ac6d254b12ed..335f47a1f44c 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -29,6 +29,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -229,8 +230,6 @@ def __init__(self, config: Dots1Config, layer_idx: int): self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - from ...integrations.hub_kernels import lazy_load_kernel - rotary_kernel = lazy_load_kernel("rotary_emb") self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else 
apply_rotary_pos_emb diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index a8b9099b0440..f3a17b7bf8ff 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -30,6 +30,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -148,8 +149,6 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = getattr(config, "sliding_window", None) - from ...integrations.hub_kernels import lazy_load_kernel - rotary_kernel = lazy_load_kernel("rotary_emb") self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index ace51914f9e4..bf9fcc750b8c 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -31,6 +31,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -443,8 +444,6 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape - from ...integrations.hub_kernels import lazy_load_kernel - rotary_kernel = lazy_load_kernel("rotary_emb") self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index a4289a7f1ca4..02f00a4c4009 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -31,6 +31,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -257,8 +258,6 @@ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape - from ...integrations.hub_kernels import lazy_load_kernel - rotary_kernel = lazy_load_kernel("rotary_emb") self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb From 7915fc1840c279efb516c6d4ef42ab421490a47c Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 29 Oct 2025 07:58:10 +0000 Subject: [PATCH 17/32] fix conflicts Signed-off-by: Liu, Kaixuan --- 
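Note: with this conflict resolved, qwen3_omni_moe imports lazy_load_kernel at module scope like
the other files touched in this series. Qwen3Next stays on the eager path because its partial
rotary embeddings, where only the first `rotary_dim` channels are rotated and the rest pass
through, are not supported by the rotary kernel yet. A minimal sketch of that partial rotation,
with names taken from the qwen3_next diff above (the function below is illustrative, not the
code added by this patch):

    import torch

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def partial_rope(x, cos, sin):
        rotary_dim = cos.shape[-1]                      # may be smaller than the head dim
        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        x_rot = (x_rot * cos) + (rotate_half(x_rot) * sin)
        return torch.cat([x_rot, x_pass], dim=-1)       # untouched channels pass through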
.../models/qwen3_omni_moe/modeling_qwen3_omni_moe.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 1f60108fcdc7..337201fa4bc5 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -35,6 +35,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import lazy_load_kernel from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -1463,8 +1464,6 @@ def __init__(self, config, layer_idx): ) # thus post q_norm does not need reshape self.sliding_window = None - from ...integrations.hub_kernels import lazy_load_kernel - rotary_kernel = lazy_load_kernel("rotary_emb") self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb @@ -2321,8 +2320,6 @@ def __init__(self, config: Qwen3OmniMoeConfig, layer_idx: int): ) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - from ...integrations.hub_kernels import lazy_load_kernel - rotary_kernel = lazy_load_kernel("rotary_emb") self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb @@ -3400,8 +3397,6 @@ def __init__(self, config: Qwen3OmniMoeCode2WavConfig, layer_idx): self.k_norm = nn.Identity() self.sliding_window = config.sliding_window - from ...integrations.hub_kernels import lazy_load_kernel - rotary_kernel = lazy_load_kernel("rotary_emb") self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb From 6f2d95888a3b6e94db76591a69649a5e9b28c4db Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 31 Oct 2025 11:02:38 +0000 Subject: [PATCH 18/32] add more check Signed-off-by: Liu, Kaixuan --- .../models/dots1/modeling_dots1.py | 8 ++++++- .../models/qwen3/modeling_qwen3.py | 8 ++++++- .../models/qwen3/modular_qwen3.py | 8 ++++++- .../models/qwen3_moe/modeling_qwen3_moe.py | 8 ++++++- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 24 ++++++++++++++++--- .../models/qwen3_vl/modeling_qwen3_vl.py | 8 ++++++- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 8 ++++++- 7 files changed, 63 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 335f47a1f44c..c28eac801143 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -231,7 +231,13 @@ def __init__(self, config: Dots1Config, layer_idx: int): self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = ( + rotary_kernel.apply_rotary_transformers + if rotary_kernel is not None + and hasattr(rotary_kernel, "apply_rotary_transformers") + and rotary_kernel.apply_rotary_transformers is not None + else apply_rotary_pos_emb + ) def forward( self, diff --git 
a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index cb8f5c32394b..7ae4f87e63d2 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -252,7 +252,13 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = ( + rotary_kernel.apply_rotary_transformers + if rotary_kernel is not None + and hasattr(rotary_kernel, "apply_rotary_transformers") + and rotary_kernel.apply_rotary_transformers is not None + else apply_rotary_pos_emb + ) def forward( self, diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index 2c32f600ba86..a20187d8fa63 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -69,7 +69,13 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = ( + rotary_kernel.apply_rotary_transformers + if rotary_kernel is not None + and hasattr(rotary_kernel, "apply_rotary_transformers") + and rotary_kernel.apply_rotary_transformers is not None + else apply_rotary_pos_emb + ) def forward( self, diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index f3a17b7bf8ff..9ef51c9f2415 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -150,7 +150,13 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): self.sliding_window = getattr(config, "sliding_window", None) rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = ( + rotary_kernel.apply_rotary_transformers + if rotary_kernel is not None + and hasattr(rotary_kernel, "apply_rotary_transformers") + and rotary_kernel.apply_rotary_transformers is not None + else apply_rotary_pos_emb + ) def forward( self, diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 337201fa4bc5..cf12f94b77bf 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1465,7 +1465,13 @@ def __init__(self, config, layer_idx): self.sliding_window = None rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = ( + rotary_kernel.apply_rotary_transformers + if rotary_kernel is not None + and hasattr(rotary_kernel, "apply_rotary_transformers") + and rotary_kernel.apply_rotary_transformers is not None + else apply_rotary_pos_emb + ) def forward( self, @@ -2321,7 +2327,13 @@ def __init__(self, config: Qwen3OmniMoeConfig, layer_idx: int): self.sliding_window = config.sliding_window if self.layer_type == 
"sliding_attention" else None rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = ( + rotary_kernel.apply_rotary_transformers + if rotary_kernel is not None + and hasattr(rotary_kernel, "apply_rotary_transformers") + and rotary_kernel.apply_rotary_transformers is not None + else apply_rotary_pos_emb + ) def forward( self, @@ -3398,7 +3410,13 @@ def __init__(self, config: Qwen3OmniMoeCode2WavConfig, layer_idx): self.sliding_window = config.sliding_window rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = ( + rotary_kernel.apply_rotary_transformers + if rotary_kernel is not None + and hasattr(rotary_kernel, "apply_rotary_transformers") + and rotary_kernel.apply_rotary_transformers is not None + else apply_rotary_pos_emb + ) def forward( self, diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index bf9fcc750b8c..5f7ccf04fbf1 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -445,7 +445,13 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): ) # thus post q_norm does not need reshape rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = ( + rotary_kernel.apply_rotary_transformers + if rotary_kernel is not None + and hasattr(rotary_kernel, "apply_rotary_transformers") + and rotary_kernel.apply_rotary_transformers is not None + else apply_rotary_pos_emb + ) def forward( self, diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 02f00a4c4009..b75ce7c50d05 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -259,7 +259,13 @@ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): ) # thus post q_norm does not need reshape rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = rotary_kernel.apply_rotary_transformers if rotary_kernel is not None else apply_rotary_pos_emb + self.rotary_fn = ( + rotary_kernel.apply_rotary_transformers + if rotary_kernel is not None + and hasattr(rotary_kernel, "apply_rotary_transformers") + and rotary_kernel.apply_rotary_transformers is not None + else apply_rotary_pos_emb + ) def forward( self, From f4b12a72f617e693285b8cf75c996e70faf6659e Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 19 Nov 2025 07:56:23 +0000 Subject: [PATCH 19/32] use decorator to map kernels for functions Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/__init__.py | 2 + src/transformers/integrations/hub_kernels.py | 21 ++++++- .../models/dots1/modeling_dots1.py | 14 +---- .../models/qwen3/modeling_qwen3.py | 14 +---- .../models/qwen3/modular_qwen3.py | 56 +++---------------- .../models/qwen3_moe/modeling_qwen3_moe.py | 14 +---- .../models/qwen3_next/modeling_qwen3_next.py | 16 +++--- .../models/qwen3_next/modular_qwen3_next.py | 32 +---------- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 36 +++--------- .../models/qwen3_vl/modeling_qwen3_vl.py | 16 ++---- .../models/qwen3_vl/modular_qwen3_vl.py | 3 +- 
.../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 16 ++---- 12 files changed, 65 insertions(+), 175 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 237d7420997f..72e047c2777e 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -72,6 +72,7 @@ "register_kernel_mapping", "replace_kernel_forward_from_hub", "use_kernel_forward_from_hub", + "use_kernel_func_from_hub", ], "integration_utils": [ "INTEGRATION_TO_CALLBACK", @@ -212,6 +213,7 @@ register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, + use_kernel_func_from_hub, ) from .integration_utils import ( INTEGRATION_TO_CALLBACK, diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 230b14161e71..9ec2336e202f 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -34,6 +34,7 @@ register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, + use_kernel_func_from_hub, ) _kernels_available = True @@ -146,6 +147,16 @@ }, } + _HUB_FUNC_KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = { + "rotary_fn": { + "xpu": { + Mode.INFERENCE: LayerRepository( + repo_id="kernels-community/rotary", layer_name="apply_rotary_transformers" + ) + } + }, + } + def has_key(d, key): return key in d or any(isinstance(v, dict) and has_key(v, key) for v in d.values()) @@ -157,6 +168,7 @@ def register_kernel_mapping_transformers(mapping=None): "kernels uses an incompatible version. Please install the latest version with `pip install -U kernels`." ) register_kernel_mapping(mapping) + register_kernel_mapping(_HUB_FUNC_KERNEL_MAPPING) except ImportError: @@ -170,6 +182,12 @@ def decorator(cls): return decorator + def use_kernel_func_from_hub(*args, **kwargs): + def decorator(cls): + return cls + + return decorator + class LayerRepository: def __init__(self, *args, **kwargs): raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.") @@ -182,10 +200,8 @@ def replace_kernel_forward_from_hub(*args, **kwargs): def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. 
Run `pip install kernels`.") - _HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = { "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"}, - "rotary_emb": {"repo_id": "kernels-community/rotary"}, } _KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {} @@ -301,6 +317,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]] __all__ = [ "LayerRepository", "use_kernel_forward_from_hub", + "use_kernel_func_from_hub", "register_kernel_mapping", "register_kernel_mapping_transformers", "replace_kernel_forward_from_hub", diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index b21662c0c49d..c0750b224dd5 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -28,8 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub -from ...integrations.hub_kernels import lazy_load_kernel +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -200,6 +199,7 @@ def eager_attention_forward( return attn_output, attn_weights +@use_kernel_func_from_hub("rotary_fn") class Dots1Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -229,15 +229,7 @@ def __init__(self, config: Dots1Config, layer_idx: int): self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - - rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = ( - rotary_kernel.apply_rotary_transformers - if rotary_kernel is not None - and hasattr(rotary_kernel, "apply_rotary_transformers") - and rotary_kernel.apply_rotary_transformers is not None - else apply_rotary_pos_emb - ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 7ae4f87e63d2..8ee3636c1162 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -28,8 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub -from ...integrations.hub_kernels import lazy_load_kernel +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -221,6 +220,7 @@ def eager_attention_forward( return attn_output, attn_weights +@use_kernel_func_from_hub("rotary_fn") class Qwen3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -250,15 +250,7 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - - rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = ( - rotary_kernel.apply_rotary_transformers - if rotary_kernel is not None - and hasattr(rotary_kernel, "apply_rotary_transformers") - and rotary_kernel.apply_rotary_transformers is not None - else apply_rotary_pos_emb - ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index a20187d8fa63..62c814742b3d 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -15,14 +15,14 @@ """PyTorch Qwen3 model.""" from collections.abc import Callable -from typing import Optional, Union +from typing import Optional import torch from ...cache_utils import Cache -from ...integrations.hub_kernels import lazy_load_kernel +from ...integrations import use_kernel_func_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging @@ -60,6 +60,7 @@ class Qwen3RotaryEmbedding(Qwen2RotaryEmbedding): pass +@use_kernel_func_from_hub("rotary_fn") class Qwen3Attention(LlamaAttention): def __init__(self, config: Qwen3Config, layer_idx: int): self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None @@ -67,15 +68,7 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - - rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = ( - rotary_kernel.apply_rotary_transformers - if rotary_kernel is not None - and hasattr(rotary_kernel, "apply_rotary_transformers") - and rotary_kernel.apply_rotary_transformers is not None - else apply_rotary_pos_emb - ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -125,16 +118,7 @@ def forward( class Qwen3ForCausalLM(Qwen2ForCausalLM): def forward( self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[TransformersKwargs], + **super_kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -158,33 +142,7 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + return super().forward(**super_kwargs) class Qwen3ForSequenceClassification(Qwen2ForSequenceClassification): diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 0da36f90f47e..58ac329e0a5e 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -29,8 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub -from ...integrations.hub_kernels import lazy_load_kernel +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -120,6 +119,7 @@ def eager_attention_forward( return attn_output, attn_weights +@use_kernel_func_from_hub("rotary_fn") class Qwen3MoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -148,15 +148,7 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = getattr(config, "sliding_window", None) - - rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = ( - rotary_kernel.apply_rotary_transformers - if rotary_kernel is not None - and hasattr(rotary_kernel, "apply_rotary_transformers") - and rotary_kernel.apply_rotary_transformers is not None - else apply_rotary_pos_emb - ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index b540e6b6026e..867b359f71fa 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -346,11 +347,11 @@ def eager_attention_forward( return attn_output, attn_weights +@use_kernel_func_from_hub("rotary_fn") class Qwen3NextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Qwen3NextConfig, layer_idx: int): - # Initialize nn.Module (skip Qwen3MoeAttention.__init__ to avoid loading rotary kernel) super().__init__() self.config = config self.layer_idx = layer_idx @@ -361,7 +362,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): self.is_causal = True self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias @@ -372,11 +373,10 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - - # qwen3_next uses partial rotary embeddings, which the rotary kernel doesn't support yet - # So we use the bamba apply_rotary_pos_emb function (imported at top) directly + self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
+ self.k_norm = Qwen3NextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape self.rotary_fn = apply_rotary_pos_emb def forward( @@ -401,7 +401,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index ec62b8df8656..3f113b44177e 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -221,34 +221,8 @@ class Qwen3NextRMSNorm(Gemma3RMSNorm): class Qwen3NextAttention(Qwen3MoeAttention): def __init__(self, config: Qwen3NextConfig, layer_idx: int): - # Initialize nn.Module (skip Qwen3MoeAttention.__init__ to avoid loading rotary kernel) - nn.Module.__init__(self) - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - - # qwen3_next uses partial rotary embeddings, which the rotary kernel doesn't support yet - # So we use the bamba apply_rotary_pos_emb function (imported at top) directly - self.rotary_fn = apply_rotary_pos_emb + super().__init__(config, layer_idx) + del self.sliding_window def forward( self, @@ -272,7 +246,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 0706c98967d5..aae01c8f7f7c 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -34,8 +34,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub -from ...integrations.hub_kernels import lazy_load_kernel +from ...integrations import use_kernel_forward_from_hub, 
use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -1431,6 +1430,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +@use_kernel_func_from_hub("rotary_fn") class Qwen3OmniMoeThinkerTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1463,15 +1463,7 @@ def __init__(self, config, layer_idx): self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape self.sliding_window = None - - rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = ( - rotary_kernel.apply_rotary_transformers - if rotary_kernel is not None - and hasattr(rotary_kernel, "apply_rotary_transformers") - and rotary_kernel.apply_rotary_transformers is not None - else apply_rotary_pos_emb - ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -2294,6 +2286,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +@use_kernel_func_from_hub("rotary_fn") class Qwen3OmniMoeTalkerCodePredictorAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -2325,15 +2318,7 @@ def __init__(self, config: Qwen3OmniMoeConfig, layer_idx: int): self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - - rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = ( - rotary_kernel.apply_rotary_transformers - if rotary_kernel is not None - and hasattr(rotary_kernel, "apply_rotary_transformers") - and rotary_kernel.apply_rotary_transformers is not None - else apply_rotary_pos_emb - ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -3318,6 +3303,7 @@ def forward(self, hidden_states): return hidden_states +@use_kernel_func_from_hub("rotary_fn") class Qwen3OmniMoeCode2WavAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -3347,15 +3333,7 @@ def __init__(self, config: Qwen3OmniMoeCode2WavConfig, layer_idx): self.q_norm = nn.Identity() self.k_norm = nn.Identity() self.sliding_window = config.sliding_window - - rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = ( - rotary_kernel.apply_rotary_transformers - if rotary_kernel is not None - and hasattr(rotary_kernel, "apply_rotary_transformers") - and rotary_kernel.apply_rotary_transformers is not None - else apply_rotary_pos_emb - ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 6f2dd9fe43fb..e30cef8a4141 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -30,8 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub -from ...integrations.hub_kernels import lazy_load_kernel +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -413,6 +412,7 @@ def apply_rotary_pos_emb(q, k, 
cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +@use_kernel_func_from_hub("rotary_fn") class Qwen3VLTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -443,15 +443,7 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): self.k_norm = Qwen3VLTextRMSNorm( self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape - - rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = ( - rotary_kernel.apply_rotary_transformers - if rotary_kernel is not None - and hasattr(rotary_kernel, "apply_rotary_transformers") - and rotary_kernel.apply_rotary_transformers is not None - else apply_rotary_pos_emb - ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -470,7 +462,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 3e4d4adeb95f..67a6de3b2723 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -57,6 +57,7 @@ Qwen3Attention, Qwen3DecoderLayer, Qwen3Model, + apply_rotary_pos_emb, eager_attention_forward, ) @@ -422,7 +423,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 3ae9bf21786a..ebbd7adea247 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -30,8 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub -from ...integrations.hub_kernels import lazy_load_kernel +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -225,6 +224,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +@use_kernel_func_from_hub("rotary_fn") class Qwen3VLMoeTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -257,15 +257,7 @@ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): self.k_norm = Qwen3VLMoeTextRMSNorm( self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape - - rotary_kernel = lazy_load_kernel("rotary_emb") - self.rotary_fn = ( - rotary_kernel.apply_rotary_transformers - if rotary_kernel is not None - and hasattr(rotary_kernel, "apply_rotary_transformers") - and 
rotary_kernel.apply_rotary_transformers is not None - else apply_rotary_pos_emb - ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -284,7 +276,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache From 6a3e6f3487d1c3aaa55b1af77792a45f4d42b9e1 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 19 Nov 2025 08:07:42 +0000 Subject: [PATCH 20/32] small fix Signed-off-by: Liu, Kaixuan --- src/transformers/models/qwen3_next/modeling_qwen3_next.py | 3 +-- src/transformers/models/qwen3_next/modular_qwen3_next.py | 3 +++ src/transformers/utils/import_utils.py | 5 ----- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 867b359f71fa..c4dd7858f0d0 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -360,9 +360,8 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 3f113b44177e..e624a653150b 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -222,6 +222,9 @@ class Qwen3NextRMSNorm(Gemma3RMSNorm): class Qwen3NextAttention(Qwen3MoeAttention): def __init__(self, config: Qwen3NextConfig, layer_idx: int): super().__init__(config, layer_idx) + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias + ) del self.sliding_window def forward( diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 55df45e946d8..b38ea64cc4ff 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -660,11 +660,6 @@ def is_flash_linear_attention_available(): return is_torch_cuda_available() and is_available and version.parse(fla_version) >= version.parse("0.2.2") -@lru_cache -def is_rotary_emb_available() -> bool: - return is_torch_xpu_available() and _is_package_available("rotary_emb") - - @lru_cache def is_causal_conv1d_available() -> bool: return is_torch_cuda_available() and _is_package_available("causal_conv1d") From 702dc0905515979327e6dd2dd3c2626e6117ecdf Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 19 Nov 2025 08:20:17 +0000 Subject: [PATCH 21/32] small adjustment Signed-off-by: Liu, Kaixuan --- src/transformers/models/qwen3_next/modeling_qwen3_next.py | 2 +- src/transformers/models/qwen3_next/modular_qwen3_next.py | 3 ++- src/transformers/models/qwen3_vl/modeling_qwen3_vl.py | 2 +- 
src/transformers/models/qwen3_vl/modular_qwen3_vl.py | 3 ++- src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 4 ++-- src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py | 4 +++- 6 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index c4dd7858f0d0..52826ba819c2 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -400,7 +400,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index e624a653150b..30a12ed1c355 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -226,6 +226,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias ) del self.sliding_window + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -249,7 +250,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index e30cef8a4141..2fcd017341e7 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -462,7 +462,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 67a6de3b2723..5d7e67d3c1a9 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -405,6 +405,7 @@ class Qwen3VLTextAttention(Qwen3Attention): def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): super().__init__(config, layer_idx) del self.sliding_window + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -423,7 +424,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are 
specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index ebbd7adea247..5c96419dbeca 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -228,7 +228,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class Qwen3VLMoeTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): + def __init__(self, config, layer_idx: int): super().__init__() self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None self.config = config @@ -276,7 +276,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py index c0c4be2ddb68..343e77019305 100644 --- a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py @@ -347,7 +347,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Qwen3VLMoeTextAttention(Qwen3VLTextAttention): - pass + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + self.rotary_fn = apply_rotary_pos_emb class Qwen3VLMoeTextDecoderLayer(Qwen3MoeDecoderLayer): From fe2bf4234e31732738cb00dbb1d4d03fdb97f675 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 19 Nov 2025 08:46:05 +0000 Subject: [PATCH 22/32] update code Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 11 ----------- .../models/qwen3_vl_moe/modular_qwen3_vl_moe.py | 1 + 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index b01aa283db38..2734601d93ff 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -42,17 +42,6 @@ _kernels_available = True _kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES - def use_kernel_forward_from_hub(layer_name: str): - if _kernels_enabled: - from kernels import use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub - - return _kernels_use_kernel_forward_from_hub(layer_name) - else: - logger.warning_once( - f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}" - ) - return lambda cls: cls - _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = { "MultiScaleDeformableAttention": { "cuda": LayerRepository( diff --git a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py index fee01709d9e4..445229ecfb97 100644 --- a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py @@ -41,6 +41,7 @@ Qwen3VLTextAttention, Qwen3VLTextModel, Qwen3VLVisionModel, + 
apply_rotary_pos_emb, ) From 6f9596957966248f5f7e91f1d1667e2d936f8aed Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 19 Nov 2025 08:54:09 +0000 Subject: [PATCH 23/32] fix LINT issue Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 2734601d93ff..ffc3749499fe 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -210,6 +210,7 @@ def replace_kernel_forward_from_hub(*args, **kwargs): def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") + _HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = { "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"}, } From e80cfd3d6e1bb10ac8de3e4eb1d610ca2629c43c Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 27 Nov 2025 09:18:59 +0000 Subject: [PATCH 24/32] update code to adapt to new `use_kernel_func_from_hub` API in kernels Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 19 ++++--------------- .../models/qwen3/modeling_qwen3.py | 2 +- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index a834aea9194f..c0179d362ab2 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -28,6 +28,7 @@ try: from kernels import ( Device, + FuncRepository, LayerRepository, Mode, get_kernel, @@ -41,7 +42,6 @@ _kernels_available = True _kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES - def use_kernel_forward_from_hub(layer_name: str): if _kernels_enabled: from kernels import use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub @@ -165,13 +165,10 @@ def use_kernel_forward_from_hub(layer_name: str): ) } }, - } - - _HUB_FUNC_KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = { - "rotary_fn": { + "rotary_pos_emb": { "xpu": { - Mode.INFERENCE: LayerRepository( - repo_id="kernels-community/rotary", layer_name="apply_rotary_transformers" + Mode.INFERENCE: FuncRepository( + repo_id="kernels-community/rotary", func_name="apply_rotary_transformers" ) } }, @@ -188,7 +185,6 @@ def register_kernel_mapping_transformers(mapping=None): "kernels uses an incompatible version. Please install the latest version with `pip install -U kernels`." ) register_kernel_mapping(mapping) - register_kernel_mapping(_HUB_FUNC_KERNEL_MAPPING) except ImportError: @@ -203,12 +199,6 @@ def decorator(cls): return decorator - def use_kernel_func_from_hub(*args, **kwargs): - def decorator(cls): - return cls - - return decorator - class LayerRepository: def __init__(self, *args, **kwargs): raise RuntimeError("LayerRepository requires `kernels` to be installed. 
Run `pip install kernels`.") @@ -339,7 +329,6 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ __all__ = [ "LayerRepository", "use_kernel_forward_from_hub", - "use_kernel_func_from_hub", "register_kernel_mapping", "register_kernel_mapping_transformers", "replace_kernel_forward_from_hub", diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index e5b020f2c6ea..2fb330dca4b7 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -155,6 +155,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -220,7 +221,6 @@ def eager_attention_forward( return attn_output, attn_weights -@use_kernel_func_from_hub("rotary_fn") class Qwen3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" From c771acd37bd7f1e72a0500661e4543908fdcde37 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 27 Nov 2025 09:40:20 +0000 Subject: [PATCH 25/32] do not consider check_modular first Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 2 +- src/transformers/models/dots1/modeling_dots1.py | 6 ++---- src/transformers/models/qwen3/modular_qwen3.py | 5 +---- .../models/qwen3_moe/modeling_qwen3_moe.py | 6 ++---- .../models/qwen3_next/modeling_qwen3_next.py | 5 +---- .../models/qwen3_next/modular_qwen3_next.py | 3 +-- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 14 ++++---------- .../models/qwen3_vl/modeling_qwen3_vl.py | 6 ++---- .../models/qwen3_vl/modular_qwen3_vl.py | 3 +-- .../models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 8 +++----- .../models/qwen3_vl_moe/modular_qwen3_vl_moe.py | 5 +---- 11 files changed, 19 insertions(+), 44 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index c0179d362ab2..6b0346ea83df 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -34,7 +34,7 @@ get_kernel, register_kernel_mapping, replace_kernel_forward_from_hub, - use_kernel_forward_from_hub, + # use_kernel_forward_from_hub, use_kernel_func_from_hub, ) diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index d0f2ccc79892..b8ae00b6ab60 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -29,7 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -200,7 +200,6 @@ def eager_attention_forward( return attn_output, attn_weights -@use_kernel_func_from_hub("rotary_fn") class Dots1Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -230,7 +229,6 @@ def __init__(self, config: Dots1Config, layer_idx: int): self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -249,7 +247,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index 62c814742b3d..2f113785a94e 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -20,7 +20,6 @@ import torch from ...cache_utils import Cache -from ...integrations import use_kernel_func_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -60,7 +59,6 @@ class Qwen3RotaryEmbedding(Qwen2RotaryEmbedding): pass -@use_kernel_func_from_hub("rotary_fn") class Qwen3Attention(LlamaAttention): def __init__(self, config: Qwen3Config, layer_idx: int): self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None @@ -68,7 +66,6 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -87,7 +84,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 529e897e9e7b..477694d5fb2b 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -120,7 +120,6 @@ def eager_attention_forward( return attn_output, attn_weights -@use_kernel_func_from_hub("rotary_fn") class Qwen3MoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -149,7 +148,6 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only 
on the head dim! self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = getattr(config, "sliding_window", None) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -168,7 +166,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 2bdfd51134d8..362c8fab007f 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -30,7 +30,6 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin -from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -348,7 +347,6 @@ def eager_attention_forward( return attn_output, attn_weights -@use_kernel_func_from_hub("rotary_fn") class Qwen3NextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -377,7 +375,6 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): self.k_norm = Qwen3NextRMSNorm( self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -401,7 +398,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 40b30e929193..7deedb9c868b 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -227,7 +227,6 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias ) del self.sliding_window - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -251,7 +250,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 60f4466e0a8d..1be0487cea98 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -35,7 +35,7 @@ from 
...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -1442,7 +1442,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -@use_kernel_func_from_hub("rotary_fn") class Qwen3OmniMoeThinkerTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1475,7 +1474,6 @@ def __init__(self, config, layer_idx): self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape self.sliding_window = None - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -1494,7 +1492,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -2324,7 +2322,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -@use_kernel_func_from_hub("rotary_fn") class Qwen3OmniMoeTalkerCodePredictorAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -2356,7 +2353,6 @@ def __init__(self, config: Qwen3OmniMoeConfig, layer_idx: int): self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -2375,7 +2371,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -3355,7 +3351,6 @@ def forward(self, hidden_states): return hidden_states -@use_kernel_func_from_hub("rotary_fn") class Qwen3OmniMoeCode2WavAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -3385,7 +3380,6 @@ def __init__(self, config: Qwen3OmniMoeCode2WavConfig, layer_idx): self.q_norm = nn.Identity() self.k_norm = nn.Identity() self.sliding_window = config.sliding_window - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -3404,7 +3398,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 5584c71b56c8..aab768b1cf1c 100644 --- 
a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -412,7 +412,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -@use_kernel_func_from_hub("rotary_fn") class Qwen3VLTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -443,7 +442,6 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): self.k_norm = Qwen3VLTextRMSNorm( self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -462,7 +460,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 4bc849dfb0f0..60253ce21551 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -405,7 +405,6 @@ class Qwen3VLTextAttention(Qwen3Attention): def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): super().__init__(config, layer_idx) del self.sliding_window - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -424,7 +423,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 7a54affaa492..eab677bce4fe 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -225,11 +225,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -@use_kernel_func_from_hub("rotary_fn") class Qwen3VLMoeTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config, layer_idx: 
int): + def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): super().__init__() self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None self.config = config @@ -258,7 +257,6 @@ def __init__(self, config, layer_idx: int): self.k_norm = Qwen3VLMoeTextRMSNorm( self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -277,7 +275,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py index d9b8e2f1d656..15e99f440baa 100644 --- a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py @@ -41,7 +41,6 @@ Qwen3VLTextAttention, Qwen3VLTextModel, Qwen3VLVisionModel, - apply_rotary_pos_emb, ) @@ -349,9 +348,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Qwen3VLMoeTextAttention(Qwen3VLTextAttention): - def __init__(self, config, layer_idx: int): - super().__init__(config, layer_idx) - self.rotary_fn = apply_rotary_pos_emb + pass class Qwen3VLMoeTextDecoderLayer(Qwen3MoeDecoderLayer): From d916ef065d62f9c2d12468d1bc7ec49755312c25 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 27 Nov 2025 09:42:33 +0000 Subject: [PATCH 26/32] update Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 6b0346ea83df..6d5acf02ac64 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -34,7 +34,6 @@ get_kernel, register_kernel_mapping, replace_kernel_forward_from_hub, - # use_kernel_forward_from_hub, use_kernel_func_from_hub, ) From 8670efec976a5ef38b9430af54c4b8f11c2104ec Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 27 Nov 2025 09:50:24 +0000 Subject: [PATCH 27/32] fix Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 6d5acf02ac64..c57a6cc2f11f 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -328,6 +328,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ __all__ = [ "LayerRepository", "use_kernel_forward_from_hub", + "use_kernel_func_from_hub", "register_kernel_mapping", "register_kernel_mapping_transformers", "replace_kernel_forward_from_hub", From fe20bd5b8f4dd49cba6b4097980e782b724d5517 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 27 Nov 2025 14:38:24 +0000 Subject: [PATCH 28/32] add compatibility for old version `kernels` Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 56 +++++++++++++++++--- 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 
c57a6cc2f11f..bada38965aed 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -28,14 +28,29 @@ try: from kernels import ( Device, - FuncRepository, LayerRepository, Mode, get_kernel, register_kernel_mapping, replace_kernel_forward_from_hub, - use_kernel_func_from_hub, ) + from kernels import ( + use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub, + ) + + # Try to import FuncRepository, fallback if not available + try: + from kernels import FuncRepository + except ImportError: + FuncRepository = None + + # Try to import use_kernel_func_from_hub, fallback if not available + try: + from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub + + _has_use_kernel_func_from_hub = True + except ImportError: + _has_use_kernel_func_from_hub = False _TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper() _kernels_available = True @@ -43,8 +58,6 @@ def use_kernel_forward_from_hub(layer_name: str): if _kernels_enabled: - from kernels import use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub - return _kernels_use_kernel_forward_from_hub(layer_name) else: logger.warning_once( @@ -52,6 +65,21 @@ def use_kernel_forward_from_hub(layer_name: str): ) return lambda cls: cls + def use_kernel_func_from_hub(func_name: str): + if _kernels_enabled and _has_use_kernel_func_from_hub: + return _kernels_use_kernel_func_from_hub(func_name) + else: + if not _has_use_kernel_func_from_hub: + logger.warning_once( + "use_kernel_func_from_hub is not available in the installed kernels version. " + "Please upgrade kernels to use this feature." + ) + else: + logger.warning_once( + f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}" + ) + return lambda func: func + _KERNEL_MAPPING: dict[str, dict[Device | str, LayerRepository]] = { "MultiScaleDeformableAttention": { "cuda": LayerRepository( @@ -164,14 +192,17 @@ def use_kernel_forward_from_hub(layer_name: str): ) } }, - "rotary_pos_emb": { + } + + # Add function kernel mappings if FuncRepository is available + if FuncRepository is not None: + _KERNEL_MAPPING["rotary_pos_emb"] = { "xpu": { Mode.INFERENCE: FuncRepository( repo_id="kernels-community/rotary", func_name="apply_rotary_transformers" ) } - }, - } + } def has_key(d, key): return key in d or any(isinstance(v, dict) and has_key(v, key) for v in d.values()) @@ -198,6 +229,12 @@ def decorator(cls): return decorator + def use_kernel_func_from_hub(*args, **kwargs): + def decorator(func): + return func + + return decorator + class LayerRepository: def __init__(self, *args, **kwargs): raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.") @@ -210,6 +247,11 @@ def replace_kernel_forward_from_hub(*args, **kwargs): def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") + def register_kernel_mapping_transformers(*args, **kwargs): + raise RuntimeError( + "register_kernel_mapping_transformers requires `kernels` to be installed. Run `pip install kernels`." 
+ ) + _HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = { "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"}, From 898e36ede2aadc37c685f08dd299a435acdae84a Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 28 Nov 2025 14:29:25 +0000 Subject: [PATCH 29/32] add rotary fn kernel to all models Signed-off-by: Liu, Kaixuan --- .../models/apertus/modeling_apertus.py | 4 +- .../models/arcee/modeling_arcee.py | 6 ++- src/transformers/models/aria/modeling_aria.py | 6 ++- .../models/bamba/modeling_bamba.py | 3 +- .../models/bitnet/modeling_bitnet.py | 4 +- .../models/cohere/modeling_cohere.py | 1 + src/transformers/models/csm/modeling_csm.py | 6 ++- src/transformers/models/cwm/modeling_cwm.py | 4 +- src/transformers/models/dbrx/modeling_dbrx.py | 2 + .../deepseek_v3/modeling_deepseek_v3.py | 3 +- src/transformers/models/dia/modeling_dia.py | 36 +--------------- .../models/diffllama/modeling_diffllama.py | 3 +- src/transformers/models/doge/modeling_doge.py | 3 +- .../models/dots1/modeling_dots1.py | 4 +- src/transformers/models/emu3/modeling_emu3.py | 6 ++- .../models/ernie4_5/modeling_ernie4_5.py | 3 +- .../ernie4_5_moe/modeling_ernie4_5_moe.py | 3 +- .../models/evolla/modeling_evolla.py | 6 ++- .../models/exaone4/modeling_exaone4.py | 3 +- .../models/falcon_h1/modeling_falcon_h1.py | 4 +- .../models/flex_olmo/modeling_flex_olmo.py | 1 + .../models/gemma/modeling_gemma.py | 5 ++- .../models/gemma2/modeling_gemma2.py | 3 ++ .../models/gemma3/modeling_gemma3.py | 3 ++ .../models/gemma3n/modeling_gemma3n.py | 2 + src/transformers/models/glm/modeling_glm.py | 3 +- src/transformers/models/glm4/modeling_glm4.py | 3 +- .../models/glm4v_moe/modeling_glm4v_moe.py | 43 +++++++++++++++++++ .../models/gpt_oss/modeling_gpt_oss.py | 1 + .../models/granite/modeling_granite.py | 6 ++- .../models/granitemoe/modeling_granitemoe.py | 6 ++- .../modeling_granitemoehybrid.py | 38 +++++++++++++++- .../modeling_granitemoeshared.py | 6 ++- .../models/helium/modeling_helium.py | 3 +- .../modeling_hunyuan_v1_dense.py | 4 +- .../hunyuan_v1_moe/modeling_hunyuan_v1_moe.py | 4 +- .../models/jamba/modeling_jamba.py | 38 +++++++++++++++- .../models/jetmoe/modeling_jetmoe.py | 3 +- src/transformers/models/lfm2/modeling_lfm2.py | 4 +- .../models/lfm2_moe/modeling_lfm2_moe.py | 4 +- .../models/lightglue/modeling_lightglue.py | 1 + .../models/llama/modeling_llama.py | 6 ++- .../models/minimax/modeling_minimax.py | 4 +- .../models/ministral/modeling_ministral.py | 4 +- .../models/mistral/modeling_mistral.py | 4 +- .../models/mixtral/modeling_mixtral.py | 4 +- .../models/modernbert/modeling_modernbert.py | 2 + .../modeling_modernbert_decoder.py | 2 + .../models/moonshine/modeling_moonshine.py | 1 + src/transformers/models/olmo/modeling_olmo.py | 1 + .../models/olmo2/modeling_olmo2.py | 1 + .../models/olmo3/modeling_olmo3.py | 1 + .../models/olmoe/modeling_olmoe.py | 4 +- .../models/parakeet/modeling_parakeet.py | 37 ++++++++++++++++ src/transformers/models/phi/modeling_phi.py | 3 ++ .../models/phimoe/modeling_phimoe.py | 6 ++- .../models/qwen2/modeling_qwen2.py | 4 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 6 ++- .../models/qwen3/modeling_qwen3.py | 4 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 4 +- .../models/qwen3_next/modeling_qwen3_next.py | 1 + .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 6 ++- .../models/qwen3_vl/modeling_qwen3_vl.py | 4 +- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 4 +- .../models/seed_oss/modeling_seed_oss.py | 3 +- .../models/smollm3/modeling_smollm3.py | 4 +- 
.../models/starcoder2/modeling_starcoder2.py | 3 ++ .../models/t5gemma/modeling_t5gemma.py | 4 ++ .../models/vaultgemma/modeling_vaultgemma.py | 3 ++ .../models/zamba2/modeling_zamba2.py | 2 + 70 files changed, 329 insertions(+), 94 deletions(-) diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index 77a3d65478d6..1780b6383cc4 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForTokenClassification, GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -147,6 +147,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -237,6 +238,7 @@ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index 779f4a63e378..15a10957c7f3 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import ( GenericForQuestionAnswering, @@ -154,6 +154,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -244,6 +245,7 @@ def __init__(self, config: ArceeConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -262,7 +264,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 96a6a82da91d..c62becaf93fb 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -29,7 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -378,6 +378,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -468,6 +469,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -486,7 +488,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 2428222e0dbe..7a5ef3370088 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -370,6 +370,7 @@ def __init__(self, config: BambaConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -388,7 +389,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index 73597dd98d82..25568ca92365 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -27,7 +27,7 @@ 
from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -85,6 +85,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -175,6 +176,7 @@ def __init__(self, config: BitNetConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 9fc2593d3175..752b693fe7b8 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -247,6 +247,7 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.use_qk_norm = config.use_qk_norm if self.use_qk_norm: # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 87da76281717..80a4b3d1c6fe 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -32,7 +32,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -206,6 +206,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -296,6 +297,7 @@ def __init__(self, config: CsmConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -314,7 +316,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py index df9760ed1ba7..53cc8c23a8c6 100644 --- a/src/transformers/models/cwm/modeling_cwm.py +++ b/src/transformers/models/cwm/modeling_cwm.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -113,6 +113,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -195,6 +196,7 @@ def __init__(self, config: CwmConfig, layer_idx: int): self.k_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None def forward( diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index ddf5fce4dfce..92df09947bc4 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -112,6 +113,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
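The hunks above and below repeat one pattern per model: the module-level `apply_rotary_pos_emb` gains the `@use_kernel_func_from_hub("rotary_pos_emb")` decorator, the attention `__init__` binds `self.rotary_fn = apply_rotary_pos_emb`, and the forward pass calls `self.rotary_fn(...)` instead of the bare function. A minimal, self-contained sketch of that indirection follows; the no-op `use_kernel_func_from_hub` stand-in and the `ToyAttention` module are illustrative assumptions, not the real `kernels` dispatcher or any transformers class.

    import torch

    def use_kernel_func_from_hub(func_name):          # illustrative no-op stand-in;
        def decorator(func):                          # the real decorator may dispatch
            return func                               # `func_name` to a hub kernel
        return decorator

    def rotate_half(x):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    @use_kernel_func_from_hub("rotary_pos_emb")
    def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

    class ToyAttention(torch.nn.Module):              # hypothetical module, for illustration
        def __init__(self):
            super().__init__()
            self.rotary_fn = apply_rotary_pos_emb     # call sites use self.rotary_fn(...)

        def forward(self, q, k, cos, sin):
            return self.rotary_fn(q, k, cos, sin)

    # q/k: [batch, heads, seq, head_dim]; cos/sin: [batch, seq, head_dim]
    q = k = torch.randn(1, 2, 4, 8)
    cos, sin = torch.randn(1, 4, 8), torch.randn(1, 4, 8)
    q_rot, k_rot = ToyAttention()(q, k, cos, sin)

Binding the function as an instance attribute keeps the eager PyTorch rotation as the default while leaving a single indirection point per attention module.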
diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index cfd8d91dfb9a..4c56277c69dd 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -16,7 +16,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -253,6 +253,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 3a0ddf6e3f90..a29eda2ff74f 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -193,40 +193,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, @@ -303,7 +269,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 99524915b9f6..b23ee2ed948d 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -32,7 +32,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask from ...modeling_layers import ( @@ -141,6 +141,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index b9ebf9856264..a3a94f88df55 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -33,7 +33,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...integrations.flex_attention import compile_friendly_flex_attention from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer @@ -143,6 +143,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
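All of these decorated functions rely on `use_kernel_func_from_hub` being importable from `...integrations`, which is why PATCH 28 above guards the optional `kernels` names. A rough sketch of that degrade-gracefully shape, condensed with the warning messages elided (names follow the patch; the hub-path behavior is assumed, not shown here):

    try:
        from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub
        _has_use_kernel_func_from_hub = True
    except ImportError:
        _has_use_kernel_func_from_hub = False

    def use_kernel_func_from_hub(func_name):
        # Use the hub dispatcher only when the installed `kernels` provides it;
        # otherwise the decorator is an identity and the eager implementation runs.
        if _has_use_kernel_func_from_hub:
            return _kernels_use_kernel_func_from_hub(func_name)
        return lambda func: func

With an older `kernels`, every `apply_rotary_pos_emb` touched by this patch therefore keeps its unmodified PyTorch path.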
diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index b8ae00b6ab60..9092a3533e43 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -29,7 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -135,6 +135,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -226,6 +227,7 @@ def __init__(self, config: Dots1Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 65671913b27f..a0eafe4db0dc 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -33,7 +33,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -52,6 +52,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -142,6 +143,7 @@ def __init__(self, config: Emu3Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -160,7 +162,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py index b53ddf923e70..4f233c456b14 100644 --- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -221,6 +221,7 @@ def __init__(self, config: Ernie4_5Config, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -239,7 +240,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index ccd05fe26347..63154a034350 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -244,6 +244,7 @@ def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -262,7 +263,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index f4d0ce11255f..7a05cb0a4062 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations 
import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -1051,6 +1051,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -1115,6 +1116,7 @@ def __init__(self, config: EvollaConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -1133,7 +1135,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/exaone4/modeling_exaone4.py b/src/transformers/models/exaone4/modeling_exaone4.py index cb70c9cff142..2287f9f8e88d 100644 --- a/src/transformers/models/exaone4/modeling_exaone4.py +++ b/src/transformers/models/exaone4/modeling_exaone4.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import ( GenericForQuestionAnswering, @@ -140,6 +140,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index a6fd7a5aba99..1cf19c500737 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -36,7 +36,7 @@ from ... import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -295,6 +295,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -385,6 +386,7 @@ def __init__(self, config: FalconH1Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.key_multiplier = config.key_multiplier def forward( diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py index b948b420ad63..993a3dae1652 100644 --- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py +++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py @@ -241,6 +241,7 @@ def __init__(self, config: FlexOlmoConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = FlexOlmoRMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) self.k_norm = FlexOlmoRMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 8834ba8c3564..d3280b45834d 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import ( GenericForSequenceClassification, @@ -152,6 +153,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -242,6 +244,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -260,7 +263,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 69e486032107..d6362634b1e3 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -156,6 +157,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -256,6 +258,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 8e93ef9231b5..253931d4eeb1 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -31,6 +31,7 @@ from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer @@ -230,6 +231,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -330,6 +332,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None self.is_sliding = self.layer_type == "sliding_attention" diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 3be6ad39ddca..3807311ff674 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1254,6 +1254,7 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None self.is_sliding = self.layer_type == "sliding_attention" @@ -1475,6 +1476,7 @@ def __init__(self, config: Gemma3nConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 8a508e2de54c..07ca34658c5e 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -239,6 +239,7 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -257,7 +258,7 @@ def 
forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index c982c36f9aab..d6d7fe16cd59 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -221,6 +221,7 @@ def __init__(self, config: Glm4Config, layer_idx: Optional[int] = None): config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -239,7 +240,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 9537d9018838..74016aa98e7e 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -204,6 +204,48 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). @@ -280,6 +322,7 @@ def __init__(self, config: Glm4vMoeTextConfig, layer_idx: Optional[int] = None): config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb self.rope_parameters = config.rope_parameters def forward( diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 8e1ce9df0b97..5e1173d823d0 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -325,6 +325,7 @@ def __init__(self, config: GptOssConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 42de2e0724f3..6e7969548aa7 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -50,6 +50,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -140,6 +141,7 @@ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -158,7 +160,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index f722ad416a2f..44eb6d705b2f 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -272,6 +272,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -362,6 +363,7 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -380,7 +382,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index c9e7245956f3..9534b45c0624 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -31,7 +31,7 @@ from ... 
import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -59,6 +59,41 @@ logger = logging.get_logger(__name__) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, @@ -122,6 +157,7 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 606a59390e6e..df9af83e23a9 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -262,6 +262,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -352,6 +353,7 @@ def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -370,7 +372,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index e616da3cd07b..c075c5ed06a9 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -243,6 +243,7 @@ def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None): config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -261,7 +262,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py index 4d184a0b1982..28899e306f4c 100644 --- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import DynamicCache from ...generation import 
GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -87,6 +87,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -177,6 +178,7 @@ def __init__(self, config: HunYuanDenseV1Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.query_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.key_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index 281a50a9e2cc..dda4366f0d4d 100644 --- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -86,6 +86,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -176,6 +177,7 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.query_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.key_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index fbda366fc319..47673a3afab2 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -32,7 +32,7 @@ from ... 
import initialization as init from ...activations import ACT2FN from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -175,6 +175,41 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return self.key_cache[layer_idx].shape[-2] +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
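
For reference, the rotate_half / apply_rotary_pos_emb pair is reproduced in full in the Jamba hunk above (and in several other files touched by this patch). The standalone sketch below exercises the shape contract its docstring describes: cos/sin of shape [batch, seq_len, head_dim] broadcast against q/k of shape [batch, heads, seq_len, head_dim] via unsqueeze_dim=1. The cos/sin construction here is illustrative only; the models compute them in their rotary-embedding modules.

    # Standalone sketch exercising the reference apply_rotary_pos_emb shown above.
    # The cos/sin construction is illustrative and stands in for what the models'
    # rotary-embedding modules actually compute.
    import torch


    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)


    def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed


    batch, heads, seq_len, head_dim = 2, 8, 16, 64
    q = torch.randn(batch, heads, seq_len, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)

    # Illustrative cos/sin: one angle per (position, rotary frequency), tiled to head_dim.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]  # [seq_len, head_dim // 2]
    emb = torch.cat((angles, angles), dim=-1)                            # [seq_len, head_dim]
    cos = emb.cos()[None].expand(batch, -1, -1)                          # [batch, seq_len, head_dim]
    sin = emb.sin()[None].expand(batch, -1, -1)

    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1 matches [batch, heads, seq, dim]
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
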
The hidden states go from (batch, @@ -229,6 +264,7 @@ def __init__(self, config: JambaConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index b102a111e10f..1ef99fc46d7f 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -366,6 +366,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index f1d639d16bbd..aa66a2e2d80c 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -26,7 +26,7 @@ from ...cache_utils import Cache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -292,6 +292,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -371,6 +372,7 @@ def __init__(self, config: Lfm2Config, layer_idx: int): self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.rotary_fn = apply_rotary_pos_emb self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index 73b9c4a8fde0..55dcde8b07e9 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MoeModelOutputWithPast @@ -363,6 +363,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -442,6 +443,7 @@ def __init__(self, config: Lfm2MoeConfig, layer_idx: int): self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.rotary_fn = apply_rotary_pos_emb self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.q_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps) self.k_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps) diff --git a/src/transformers/models/lightglue/modeling_lightglue.py b/src/transformers/models/lightglue/modeling_lightglue.py index 0d5044c5a40a..e684b11130d8 100644 --- a/src/transformers/models/lightglue/modeling_lightglue.py +++ b/src/transformers/models/lightglue/modeling_lightglue.py @@ -199,6 +199,7 @@ def __init__(self, config: LightGlueConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e3adac5d117d..5aa004cee8c7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -26,7 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import ( GenericForQuestionAnswering, @@ -142,6 +142,7 
@@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -248,6 +249,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -266,7 +268,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 004ed68cef23..8b8f8c9adb3a 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -326,6 +326,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -407,6 +408,7 @@ def __init__(self, config: MiniMaxConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/ministral/modeling_ministral.py b/src/transformers/models/ministral/modeling_ministral.py index b1c8555fd96b..ccf6c9abcda5 100644 --- a/src/transformers/models/ministral/modeling_ministral.py +++ b/src/transformers/models/ministral/modeling_ministral.py @@ -13,7 +13,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -54,6 +54,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -137,6 +138,7 @@ def __init__(self, config, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None def forward( diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 60c7e2d49eed..43dc5c18a90b 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -15,7 +15,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -55,6 +55,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -136,6 +137,7 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1faff1f4dcea..95b236dadce6 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -37,7 +37,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -225,6 +225,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -306,6 +307,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 8069f2bec2ff..4869da818e01 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -31,6 +31,7 @@ from ... import initialization as init from ...activations import ACT2FN +from ...integrations import use_kernel_func_from_hub from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -331,6 +332,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index 7564e375716b..5059e623c9dd 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -31,6 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast @@ -183,6 +184,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 373e1db4a217..87893284ec4e 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -264,6 +264,7 @@ def __init__( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb # Pad head dimension to the next specified multiple. 
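
Throughout these hunks the change is the same: each attention module stores self.rotary_fn = apply_rotary_pos_emb at construction time and, where its forward is touched, calls self.rotary_fn(...) instead of the module-level function. The sketch below shows only that indirection; ToyAttention and fused_rotary_stub are made-up names, and the actual replacement in this series goes through the @use_kernel_func_from_hub decorator, whose implementation is not part of this excerpt.

    # Minimal sketch of the rotary_fn indirection introduced by these hunks.
    # ToyAttention and fused_rotary_stub are hypothetical; only the
    # attribute-plus-call pattern mirrors the diffs.
    import torch
    from torch import nn


    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)


    def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


    class ToyAttention(nn.Module):
        def __init__(self):
            super().__init__()
            self.rotary_fn = apply_rotary_pos_emb  # default: eager Python implementation

        def forward(self, q, k, cos, sin):
            # Whichever callable is bound here gets used; the forward never changes again.
            return self.rotary_fn(q, k, cos, sin)


    def fused_rotary_stub(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        # Stand-in for a hub-provided kernel; here it simply defers to the eager path.
        return apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim)


    attn = ToyAttention()
    attn.rotary_fn = fused_rotary_stub  # a later integration step can rebind the callable
    q = k = torch.randn(1, 4, 8, 32)
    cos = sin = torch.randn(1, 8, 32)
    q_rot, k_rot = attn(q, k, cos, sin)
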
if self.config.pad_head_dim_to_multiple_of is not None: diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 2ba7a25f71b5..1dcaa6247155 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -237,6 +237,7 @@ def __init__(self, config: OlmoConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 44e3592157af..aa6296f344f1 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -230,6 +230,7 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) diff --git a/src/transformers/models/olmo3/modeling_olmo3.py b/src/transformers/models/olmo3/modeling_olmo3.py index d49570982f48..003e691e9c82 100644 --- a/src/transformers/models/olmo3/modeling_olmo3.py +++ b/src/transformers/models/olmo3/modeling_olmo3.py @@ -161,6 +161,7 @@ def __init__(self, config: Olmo3Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Olmo3RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) self.k_norm = Olmo3RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) assert config.layer_types is not None diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index f078518e0c1f..d04cd421d441 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -27,7 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -148,6 +148,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -238,6 +239,7 @@ def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.k_norm = OlmoeRMSNorm( (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 079307547fe5..d11d56d3a11c 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -29,6 +29,7 @@ from ... import initialization as init from ...activations import ACT2FN +from ...integrations import use_kernel_func_from_hub from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -182,6 +183,41 @@ def forward(self, hidden_states, attention_mask=None): return hidden_states.transpose(1, 2) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
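
Only the docstring of repeat_kv is visible in these hunks. To make its equivalence claim concrete, here is the usual expand/reshape formulation checked against torch.repeat_interleave; treat it as a sketch rather than the exact upstream body.

    import torch


    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Usual expand/reshape formulation, shown only to make the docstring's
        # equivalence claim concrete; the upstream body is elided in this diff.
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


    kv = torch.randn(2, 4, 16, 64)  # (batch, num_key_value_heads, seq_len, head_dim)
    assert torch.equal(repeat_kv(kv, 3), torch.repeat_interleave(kv, repeats=3, dim=1))
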
The hidden states go from (batch, @@ -245,6 +281,7 @@ def __init__(self, config: ParakeetEncoderConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb # W_{k,R} projection self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) # global content bias diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 4a1530b78564..eec8794a3d65 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -13,6 +13,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import ( GenericForSequenceClassification, @@ -105,6 +106,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -185,6 +187,7 @@ def __init__(self, config: PhiConfig, layer_idx: int): self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.rotary_fn = apply_rotary_pos_emb self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) self.qk_layernorm = config.qk_layernorm diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 12e41214094d..0b71816b0c86 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -128,6 +128,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -218,6 +219,7 @@ def __init__(self, config: PhimoeConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -236,7 +238,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 1215f3677603..060683cfb972 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -13,7 +13,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -119,6 +119,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -201,6 +202,7 @@ def __init__(self, config: Qwen2Config, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None def forward( diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 8bda140d3cdb..8ea900211aae 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -35,7 +35,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import ( GenericForQuestionAnswering, @@ -161,6 +161,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -243,6 +244,7 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb if self.config.layer_types[layer_idx] == "sliding_attention": self.sliding_window = config.sliding_window @@ -263,7 +265,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 2fb330dca4b7..9b6c41662929 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -247,10 +247,10 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -269,7 +269,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 477694d5fb2b..caf7e26a39fe 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -55,6 +55,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -145,6 +146,7 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = getattr(config, "sliding_window", None) diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 362c8fab007f..27581014e9f7 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -371,6 +371,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3NextRMSNorm( self.head_dim, eps=config.rms_norm_eps diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 1be0487cea98..73709df5d0bc 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -35,7 +35,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -1415,6 +1415,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -1467,6 +1468,7 @@ def __init__(self, config, layer_idx): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3OmniMoeThinkerTextRMSNorm( self.head_dim, eps=config.rms_norm_eps ) # unlike olmo, only on the head dim! @@ -2348,6 +2350,7 @@ def __init__(self, config: Qwen3OmniMoeConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3OmniMoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
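
The "# unlike olmo, only on the head dim!" comments refer to the shape of the norm weight: a q_norm parameterized by head_dim can be applied after the heads are split, while an Olmo2-style norm over num_attention_heads * head_dim has to see the flat projection first. A small illustration with a plain RMS norm follows; SimpleRMSNorm and the shapes are illustrative, not the model classes.

    # Sketch of the comment above: a per-head-dim RMS norm (weight shape = head_dim)
    # applies directly to already-split heads, so no reshape is needed after q_norm.
    import torch
    from torch import nn


    class SimpleRMSNorm(nn.Module):
        def __init__(self, dim, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(dim))
            self.eps = eps

        def forward(self, x):
            variance = x.pow(2).mean(-1, keepdim=True)
            return self.weight * x * torch.rsqrt(variance + self.eps)


    batch, seq, heads, head_dim = 2, 16, 8, 64
    q = torch.randn(batch, seq, heads, head_dim)

    per_head_norm = SimpleRMSNorm(head_dim)            # Qwen3-style: normalizes each head independently
    q_normed = per_head_norm(q)                        # works directly on [batch, seq, heads, head_dim]

    full_width_norm = SimpleRMSNorm(heads * head_dim)  # Olmo2-style: needs the flat [.., heads * head_dim] layout
    q_flat_normed = full_width_norm(q.reshape(batch, seq, heads * head_dim)).reshape(batch, seq, heads, head_dim)
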
self.k_norm = Qwen3OmniMoeRMSNorm( self.head_dim, eps=config.rms_norm_eps @@ -3377,6 +3380,7 @@ def __init__(self, config: Qwen3OmniMoeCode2WavConfig, layer_idx): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = nn.Identity() self.k_norm = nn.Identity() self.sliding_window = config.sliding_window diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index aab768b1cf1c..9922ccb09bb6 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -385,6 +385,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -438,6 +439,7 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3VLTextRMSNorm( self.head_dim, eps=config.rms_norm_eps diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index eab677bce4fe..28e2c85f156c 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -198,6 +198,7 @@ def eager_attention_forward( return attn_output, attn_weights +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -251,6 +252,7 @@ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3VLMoeTextRMSNorm( self.head_dim, eps=config.rms_norm_eps ) # unlike olmo, only on the head dim! 
diff --git a/src/transformers/models/seed_oss/modeling_seed_oss.py b/src/transformers/models/seed_oss/modeling_seed_oss.py index 682193ca8d51..877d5eaa4fc0 100644 --- a/src/transformers/models/seed_oss/modeling_seed_oss.py +++ b/src/transformers/models/seed_oss/modeling_seed_oss.py @@ -27,7 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import ( GenericForQuestionAnswering, @@ -90,6 +90,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index e23d4993e84c..c59906f63692 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -118,6 +118,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -208,6 +209,7 @@ def __init__(self, config: SmolLM3Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.use_rope = config.no_rope_layers[layer_idx] self.sliding_window = ( diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 042033fe3565..c013e97f2169 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -35,6 +35,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -74,6 +75,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -155,6 +157,7 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + self.rotary_fn = apply_rotary_pos_emb self.residual_dropout = config.residual_dropout def forward( diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index ac9b64929280..5c0410655ee4 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -162,6 +163,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -263,6 +265,7 @@ def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None @@ -338,6 +341,7 @@ def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_logit_softcapping = self.config.attn_logit_softcapping if config.cross_attention_hidden_size is None: diff --git a/src/transformers/models/vaultgemma/modeling_vaultgemma.py b/src/transformers/models/vaultgemma/modeling_vaultgemma.py index ee36d7519a53..ad93f819fa0e 100644 --- a/src/transformers/models/vaultgemma/modeling_vaultgemma.py +++ b/src/transformers/models/vaultgemma/modeling_vaultgemma.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -87,6 +88,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
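
Each apply_rotary_pos_emb touched by this patch gains the @use_kernel_func_from_hub("rotary_pos_emb") decorator imported from ...integrations. Its implementation is not included in this excerpt; the sketch below is only a plausible shape consistent with how it is invoked (a decorator factory that keeps the decorated eager function as the fallback) and should not be read as the actual transformers code.

    # NOT the actual transformers implementation: use_kernel_func_from_hub lives in
    # the integrations package, outside this excerpt. This is a hypothetical sketch
    # matching the call pattern @use_kernel_func_from_hub("rotary_pos_emb") in the diffs.
    import functools
    from typing import Callable, Optional

    _HUB_FUNCS: dict[str, Callable] = {}  # hypothetical registry of hub-provided callables


    def use_kernel_func_from_hub(func_name: str):
        def decorator(python_impl: Callable) -> Callable:
            @functools.wraps(python_impl)
            def wrapper(*args, **kwargs):
                # Prefer a registered hub kernel for `func_name`; otherwise fall back
                # to the eager Python implementation being decorated.
                kernel: Optional[Callable] = _HUB_FUNCS.get(func_name)
                return (kernel or python_impl)(*args, **kwargs)

            return wrapper

        return decorator


    @use_kernel_func_from_hub("rotary_pos_emb")
    def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        ...  # eager implementation as reproduced elsewhere in this patch
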
@@ -187,6 +189,7 @@ def __init__(self, config: VaultGemmaConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 5b5f532cebf7..9ade9d1dfc8d 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -32,6 +32,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -316,6 +317,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. From b8b68c7caf466257b88ef910efa18d0812c33597 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 28 Nov 2025 15:31:40 +0000 Subject: [PATCH 30/32] update modular part Signed-off-by: Liu, Kaixuan --- .../models/apertus/modeling_apertus.py | 2 +- .../models/apertus/modular_apertus.py | 3 +- .../models/bamba/modeling_bamba.py | 2 +- .../models/bamba/modular_bamba.py | 4 +- .../models/bitnet/modeling_bitnet.py | 2 +- .../models/bitnet/modular_bitnet.py | 3 +- .../models/cohere/modeling_cohere.py | 2 +- .../models/cohere/modular_cohere.py | 3 +- .../models/cohere2/modeling_cohere2.py | 3 +- .../models/cohere2/modular_cohere2.py | 3 +- src/transformers/models/cwm/modeling_cwm.py | 2 +- src/transformers/models/dbrx/modeling_dbrx.py | 3 +- src/transformers/models/dbrx/modular_dbrx.py | 3 +- .../deepseek_v3/modeling_deepseek_v3.py | 3 +- .../models/deepseek_v3/modular_deepseek_v3.py | 3 +- .../models/diffllama/modeling_diffllama.py | 7 +-- .../models/diffllama/modular_diffllama.py | 7 +-- .../models/dinov3_vit/modeling_dinov3_vit.py | 3 +- .../models/dinov3_vit/modular_dinov3_vit.py | 3 +- src/transformers/models/doge/modeling_doge.py | 3 +- src/transformers/models/doge/modular_doge.py | 3 +- .../models/dots1/modeling_dots1.py | 2 +- .../models/ernie4_5/modular_ernie4_5.py | 1 + .../models/exaone4/modeling_exaone4.py | 3 +- .../models/exaone4/modular_exaone4.py | 3 +- .../models/falcon_h1/modeling_falcon_h1.py | 2 +- .../models/falcon_h1/modular_falcon_h1.py | 3 +- .../models/flex_olmo/modeling_flex_olmo.py | 2 +- .../models/gemma2/modeling_gemma2.py | 2 +- .../models/gemma2/modular_gemma2.py | 3 +- .../models/gemma3/modeling_gemma3.py | 2 +- .../models/gemma3/modular_gemma3.py | 3 +- .../models/gemma3n/modeling_gemma3n.py | 6 +-- .../models/gemma3n/modular_gemma3n.py | 5 +- src/transformers/models/glm/modular_glm.py | 1 + .../models/glm4_moe/modeling_glm4_moe.py | 47 +------------------ .../models/gpt_neox/modeling_gpt_neox.py | 3 +- .../models/gpt_neox/modular_gpt_neox.py | 3 +- .../models/gpt_oss/modeling_gpt_oss.py | 2 +- .../models/gpt_oss/modular_gpt_oss.py | 3 +- .../models/helium/modular_helium.py | 1 + .../modeling_hunyuan_v1_dense.py | 2 +- .../modular_hunyuan_v1_dense.py | 3 +- 
.../hunyuan_v1_moe/modeling_hunyuan_v1_moe.py | 2 +- .../hunyuan_v1_moe/modular_hunyuan_v1_moe.py | 3 +- .../models/jetmoe/modeling_jetmoe.py | 3 +- .../models/jetmoe/modular_jetmoe.py | 3 +- src/transformers/models/lfm2/modeling_lfm2.py | 2 +- src/transformers/models/lfm2/modular_lfm2.py | 3 +- .../models/lfm2_moe/modeling_lfm2_moe.py | 2 +- .../models/lightglue/modeling_lightglue.py | 18 ++++--- .../models/lightglue/modular_lightglue.py | 3 +- .../longcat_flash/modeling_longcat_flash.py | 31 +++++++++++- .../models/minimax/modeling_minimax.py | 2 +- .../models/ministral/modeling_ministral.py | 2 +- .../models/mistral/modeling_mistral.py | 2 +- .../models/mistral/modular_mistral.py | 3 +- .../models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/mlcd/modeling_mlcd.py | 3 +- src/transformers/models/mlcd/modular_mlcd.py | 3 +- .../models/modernbert/modeling_modernbert.py | 5 +- .../models/modernbert/modular_modernbert.py | 5 +- .../modeling_modernbert_decoder.py | 3 +- .../modular_modernbert_decoder.py | 3 +- .../models/moonshine/modeling_moonshine.py | 2 +- .../models/moonshine/modular_moonshine.py | 3 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/olmo/modular_olmo.py | 6 ++- .../models/olmo2/modeling_olmo2.py | 2 +- .../models/olmo2/modular_olmo2.py | 3 +- .../models/olmo3/modeling_olmo3.py | 2 +- .../models/olmo3/modular_olmo3.py | 3 +- .../models/olmoe/modeling_olmoe.py | 2 +- .../models/olmoe/modular_olmoe.py | 3 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi/modular_phi.py | 3 +- src/transformers/models/phi3/modeling_phi3.py | 3 +- src/transformers/models/phi3/modular_phi3.py | 3 +- .../modeling_phi4_multimodal.py | 3 +- .../models/qwen2/modeling_qwen2.py | 2 +- .../models/qwen2/modular_qwen2.py | 3 +- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 8 ++-- .../qwen2_5_omni/modular_qwen2_5_omni.py | 8 ++-- .../models/qwen3/modeling_qwen3.py | 2 +- .../models/qwen3/modular_qwen3.py | 3 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 2 +- .../models/qwen3_next/modeling_qwen3_next.py | 2 +- .../models/qwen3_next/modular_qwen3_next.py | 3 +- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 6 +-- .../models/qwen3_vl/modeling_qwen3_vl.py | 2 +- .../models/qwen3_vl/modular_qwen3_vl.py | 3 +- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 2 +- .../models/seed_oss/modeling_seed_oss.py | 3 +- .../models/seed_oss/modular_seed_oss.py | 3 +- .../models/smollm3/modeling_smollm3.py | 2 +- .../models/smollm3/modular_smollm3.py | 3 +- .../models/starcoder2/modeling_starcoder2.py | 2 +- .../models/starcoder2/modular_starcoder2.py | 3 +- .../models/t5gemma/modeling_t5gemma.py | 2 +- .../models/vaultgemma/modeling_vaultgemma.py | 2 +- .../video_llama_3/modeling_video_llama_3.py | 3 +- .../video_llama_3/modular_video_llama_3.py | 3 +- .../models/zamba2/modeling_zamba2.py | 3 +- .../models/zamba2/modular_zamba2.py | 3 +- 104 files changed, 217 insertions(+), 175 deletions(-) diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index 1780b6383cc4..61f7601f7f63 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -261,7 +261,7 @@ def forward( key_states = self.k_norm(key_states) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not 
None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py index a60daa2f8194..d75a9137d09f 100644 --- a/src/transformers/models/apertus/modular_apertus.py +++ b/src/transformers/models/apertus/modular_apertus.py @@ -34,7 +34,6 @@ LlamaPreTrainedModel, LlamaRMSNorm, LlamaRotaryEmbedding, - apply_rotary_pos_emb, eager_attention_forward, ) from ..nemotron.modeling_nemotron import NemotronMLP @@ -225,7 +224,7 @@ def forward( key_states = self.k_norm(key_states) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 7a5ef3370088..a0aae79d3899 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -348,7 +348,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class BambaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: BambaConfig, layer_idx: int): + def __init__(self, *args, **kwargs): super().__init__() self.config = config self.layer_idx = layer_idx diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index d29273940b8a..f31518bd6ed5 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -198,7 +198,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class BambaAttention(LlamaAttention): - pass + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.rotary_fn = apply_rotary_pos_emb class BambaRMSNormGated(MambaRMSNormGated): diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index 25568ca92365..eb2807657950 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -196,7 +196,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/bitnet/modular_bitnet.py b/src/transformers/models/bitnet/modular_bitnet.py index 093eb2428395..2ca747807afa 100644 --- a/src/transformers/models/bitnet/modular_bitnet.py +++ b/src/transformers/models/bitnet/modular_bitnet.py @@ -58,6 +58,7 @@ class BitNetAttention(LlamaAttention): def __init__(self, config: BitNetConfig, layer_idx: int): super().__init__(config, layer_idx) self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -76,7 +77,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 
+ query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 752b693fe7b8..556b007ce320 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -283,7 +283,7 @@ def forward( value_states = value_states.transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index 5147e5638eb2..acb44b9dbff2 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ b/src/transformers/models/cohere/modular_cohere.py @@ -144,6 +144,7 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None): self.k_norm = CohereLayerNorm( hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -170,7 +171,7 @@ def forward( value_states = value_states.transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index a9c56cd2491c..5e0423c2875d 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -225,6 +225,7 @@ def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -244,7 +245,7 @@ def forward( cos, sin = position_embeddings if self.sliding_window is not None: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 9e8bc6b564e4..afc3f0e8cdbb 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -270,6 +270,7 @@ def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -289,7 +290,7 @@ def forward( cos, sin = position_embeddings if self.sliding_window is not None: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = 
{"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py index 53cc8c23a8c6..39ae259fe0a7 100644 --- a/src/transformers/models/cwm/modeling_cwm.py +++ b/src/transformers/models/cwm/modeling_cwm.py @@ -216,7 +216,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 92df09947bc4..faee15b60f30 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -209,6 +209,7 @@ def __init__( self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False ) self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -240,7 +241,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/dbrx/modular_dbrx.py b/src/transformers/models/dbrx/modular_dbrx.py index 42a9079cb012..7edffa0a70a6 100644 --- a/src/transformers/models/dbrx/modular_dbrx.py +++ b/src/transformers/models/dbrx/modular_dbrx.py @@ -77,6 +77,7 @@ def __init__( self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False ) self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -108,7 +109,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 4c56277c69dd..2d52144add58 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -406,6 +406,7 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): config.hidden_size, bias=config.attention_bias, ) + self.rotary_fn = apply_rotary_pos_emb self.scaling = self.qk_head_dim ** (-0.5) if self.config.rope_parameters.get("rope_type", "default") != "default": @@ -447,7 +448,7 @@ def forward( if self.config.rope_interleave: # support using interleaved weights for efficiency q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) else: - q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + q_rot, k_rot = self.rotary_fn(q_rot, k_rot, cos, sin) k_rot = k_rot.expand(*k_pass.shape[:-1], -1) query_states = torch.cat((q_pass, 
q_rot), dim=-1) diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 97c3f0cd425b..6f744389bc35 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -208,6 +208,7 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): config.hidden_size, bias=config.attention_bias, ) + self.rotary_fn = apply_rotary_pos_emb self.scaling = self.qk_head_dim ** (-0.5) if self.config.rope_parameters.get("rope_type", "default") != "default": @@ -249,7 +250,7 @@ def forward( if self.config.rope_interleave: # support using interleaved weights for efficiency q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) else: - q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + q_rot, k_rot = self.rotary_fn(q_rot, k_rot, cos, sin) k_rot = k_rot.expand(*k_pass.shape[:-1], -1) query_states = torch.cat((q_pass, q_rot), dim=-1) diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index b23ee2ed948d..52210d35135a 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -220,6 +220,7 @@ def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None): self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -244,7 +245,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -329,7 +330,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -454,7 +455,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 904881ed7fbd..13a763f54fdd 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -95,6 +95,7 @@ def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None): self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, 
size=(self.head_dim,))) self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -119,7 +120,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -204,7 +205,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -329,7 +330,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py index 09edeed17543..ad0cba3f689a 100644 --- a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py @@ -274,6 +274,7 @@ def __init__(self, config: DINOv3ViTConfig): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias) self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -295,7 +296,7 @@ def forward( value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/dinov3_vit/modular_dinov3_vit.py b/src/transformers/models/dinov3_vit/modular_dinov3_vit.py index 7ae7d1632053..85bf818c1bf9 100644 --- a/src/transformers/models/dinov3_vit/modular_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modular_dinov3_vit.py @@ -231,6 +231,7 @@ def __init__(self, config: DINOv3ViTConfig): self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias) self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -252,7 +253,7 @@ def forward( value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + 
query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index a3a94f88df55..303808e50d20 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -289,6 +289,7 @@ def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): ) self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -308,7 +309,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/doge/modular_doge.py b/src/transformers/models/doge/modular_doge.py index eacea60cf442..629273fe87f2 100644 --- a/src/transformers/models/doge/modular_doge.py +++ b/src/transformers/models/doge/modular_doge.py @@ -325,6 +325,7 @@ def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): ) self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -344,7 +345,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 9092a3533e43..8a1cdfd06fed 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -249,7 +249,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/ernie4_5/modular_ernie4_5.py b/src/transformers/models/ernie4_5/modular_ernie4_5.py index 780b07164ec0..3d41f86c951d 100644 --- a/src/transformers/models/ernie4_5/modular_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modular_ernie4_5.py @@ -101,6 +101,7 @@ def __init__(self, config: Ernie4_5Config, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + self.rotary_fn = apply_rotary_pos_emb class 
Ernie4_5ForCausalLM(LlamaForCausalLM): diff --git a/src/transformers/models/exaone4/modeling_exaone4.py b/src/transformers/models/exaone4/modeling_exaone4.py index 2287f9f8e88d..f6e8c7966ce2 100644 --- a/src/transformers/models/exaone4/modeling_exaone4.py +++ b/src/transformers/models/exaone4/modeling_exaone4.py @@ -231,6 +231,7 @@ def __init__(self, config: Exaone4Config, layer_idx: int): self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -256,7 +257,7 @@ def forward( cos, sin = position_embeddings # We use global NoPE for hybrid attention model if self.sliding_window is None or self.is_sliding: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = { diff --git a/src/transformers/models/exaone4/modular_exaone4.py b/src/transformers/models/exaone4/modular_exaone4.py index d6004db0d28c..333a966592ba 100644 --- a/src/transformers/models/exaone4/modular_exaone4.py +++ b/src/transformers/models/exaone4/modular_exaone4.py @@ -258,6 +258,7 @@ def __init__(self, config: Exaone4Config, layer_idx: int): self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -283,7 +284,7 @@ def forward( cos, sin = position_embeddings # We use global NoPE for hybrid attention model if self.sliding_window is None or self.is_sliding: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = { diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 1cf19c500737..d3effa1d6915 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -406,7 +406,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 386266b89bf3..f191df88dfdb 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -205,6 +205,7 @@ class FalconH1Attention(LlamaAttention): def __init__(self, config: FalconH1Config, layer_idx: int): super().__init__(config, layer_idx) self.key_multiplier = config.key_multiplier + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -223,7 +224,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position 
needed for the static cache diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py index 993a3dae1652..b4de3e67fffb 100644 --- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py +++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py @@ -267,7 +267,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index d6362634b1e3..956fc3e6f8e0 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -279,7 +279,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index f33a0c7ba666..5058d39f9335 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -313,6 +313,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): self.is_causal = not getattr(config, "use_bidirectional_attention", False) self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -331,7 +332,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 253931d4eeb1..e59116df9f37 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -360,7 +360,7 @@ def forward( key_states = self.k_norm(key_states) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 7f9e987035ad..5c499da4b323 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -450,6 +450,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int): self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) self.k_norm = 
Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -471,7 +472,7 @@ def forward( key_states = self.k_norm(key_states) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 3807311ff674..59fea8fcbb31 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1291,7 +1291,7 @@ def forward( cos, sin = position_embeddings query_states = self.q_proj(hidden_states).view(hidden_shape) query_states = self.q_norm(query_states) - query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) + query_states = self.rotary_fn(query_states, cos, sin, unsqueeze_dim=2) query_states = query_states.transpose(1, 2) # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer @@ -1303,7 +1303,7 @@ def forward( else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) - key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) + key_states = self.rotary_fn(key_states, cos, sin, unsqueeze_dim=2) key_states = key_states.transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape) @@ -1497,7 +1497,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index a226824781cf..a4b7aeca8777 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1720,6 +1720,7 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int): self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index( config.layer_types[layer_idx] ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -1736,7 +1737,7 @@ def forward( cos, sin = position_embeddings query_states = self.q_proj(hidden_states).view(hidden_shape) query_states = self.q_norm(query_states) - query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) + query_states = self.rotary_fn(query_states, cos, sin, unsqueeze_dim=2) query_states = query_states.transpose(1, 2) # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer @@ -1748,7 +1749,7 @@ def forward( else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) - key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) + key_states = self.rotary_fn(key_states, cos, sin, unsqueeze_dim=2) key_states = key_states.transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape) diff --git a/src/transformers/models/glm/modular_glm.py 
b/src/transformers/models/glm/modular_glm.py index 059cb296c972..76532939ec7f 100644 --- a/src/transformers/models/glm/modular_glm.py +++ b/src/transformers/models/glm/modular_glm.py @@ -126,6 +126,7 @@ class GlmAttention(LlamaAttention): def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb class GlmForCausalLM(LlamaForCausalLM): diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py index 84e6dd3bd77d..a914f44051c7 100644 --- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -148,51 +148,6 @@ def eager_attention_forward( return attn_output, attn_weights -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - # Keep half or full tensor for later concatenation - rotary_dim = cos.shape[-1] - q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] - k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - - # Apply rotary embeddings on the first half or full tensor - q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) - k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) - - # Concatenate back to full shape - q_embed = torch.cat([q_embed, q_pass], dim=-1) - k_embed = torch.cat([k_embed, k_pass], dim=-1) - return q_embed, k_embed - - class Glm4MoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -247,7 +202,7 @@ def forward( value_states = value_states.transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index fc7d6fd40a80..4755872ebcab 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -201,6 +201,7 @@ def __init__(self, config, layer_idx=None): self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias) self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -218,7 +219,7 @@ def forward( query_states, key_states, value_states = qkv.chunk(3, dim=-1) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) # Cache QKV values if layer_past is not None: diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index c267753db350..bc7d222fda5b 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -153,6 +153,7 @@ def __init__(self, config, layer_idx=None): self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias) self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -170,7 +171,7 @@ def forward( query_states, key_states, value_states = qkv.chunk(3, dim=-1) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) # Cache QKV values if layer_past is not None: diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 5e1173d823d0..745b06ae56a2 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -347,7 +347,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, 
key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 57acfea8df64..a0767155415a 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -254,6 +254,7 @@ def __init__(self, config: GptOssConfig, layer_idx: int): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -273,7 +274,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} diff --git a/src/transformers/models/helium/modular_helium.py b/src/transformers/models/helium/modular_helium.py index 79d995720e30..b99477992385 100644 --- a/src/transformers/models/helium/modular_helium.py +++ b/src/transformers/models/helium/modular_helium.py @@ -99,6 +99,7 @@ def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) self.scaling = 1 / math.sqrt(self.head_dim) + self.rotary_fn = apply_rotary_pos_emb class HeliumDecoderLayer(LlamaDecoderLayer): diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py index 28899e306f4c..66709468c089 100644 --- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py @@ -199,7 +199,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) query_states = self.query_layernorm(query_states) key_states = self.key_layernorm(key_states) diff --git a/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py index d41b5f236759..2462ab746e4a 100644 --- a/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py @@ -66,6 +66,7 @@ def __init__(self, config: HunYuanDenseV1Config, layer_idx: int): super().__init__(config, layer_idx) self.query_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.key_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -84,7 +85,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) query_states = self.query_layernorm(query_states) key_states = self.key_layernorm(key_states) diff --git 
a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index dda4366f0d4d..7986b4a92daa 100644 --- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -198,7 +198,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) query_states = self.query_layernorm(query_states) key_states = self.key_layernorm(key_states) diff --git a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py index f94c4ed4c2c7..2abf15e24649 100644 --- a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py @@ -36,7 +36,6 @@ LlamaModel, LlamaPreTrainedModel, LlamaRMSNorm, - apply_rotary_pos_emb, eager_attention_forward, ) from ..mixtral.modeling_mixtral import MixtralExperts @@ -81,7 +80,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) query_states = self.query_layernorm(query_states) key_states = self.key_layernorm(key_states) diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 1ef99fc46d7f..d21f4159e285 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -469,6 +469,7 @@ def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None): self.experts = JetMoeMoA(config) self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -490,7 +491,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/jetmoe/modular_jetmoe.py b/src/transformers/models/jetmoe/modular_jetmoe.py index db8c3e1059c0..94dd2d508300 100644 --- a/src/transformers/models/jetmoe/modular_jetmoe.py +++ b/src/transformers/models/jetmoe/modular_jetmoe.py @@ -324,6 +324,7 @@ def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None): self.experts = JetMoeMoA(config) self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -345,7 +346,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git 
a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index aa66a2e2d80c..2d6c726672cd 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -395,7 +395,7 @@ def forward( value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index 0075280e6ddb..3a12d12b004b 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -33,7 +33,6 @@ LlamaModel, LlamaPreTrainedModel, LlamaRMSNorm, - apply_rotary_pos_emb, eager_attention_forward, ) from .configuration_lfm2 import Lfm2Config @@ -244,7 +243,7 @@ def forward( value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index 55dcde8b07e9..e4d54cd78abe 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -466,7 +466,7 @@ def forward( value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/lightglue/modeling_lightglue.py b/src/transformers/models/lightglue/modeling_lightglue.py index e684b11130d8..83adcea63a6d 100644 --- a/src/transformers/models/lightglue/modeling_lightglue.py +++ b/src/transformers/models/lightglue/modeling_lightglue.py @@ -27,6 +27,7 @@ from torch.nn.utils.rnn import pad_sequence from ...activations import ACT2FN +from ...integrations import use_kernel_func_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -99,13 +100,13 @@ def forward( def rotate_half(x): - # Split and rotate. Note that this function is different from e.g. Llama. - x1 = x[..., ::2] - x2 = x[..., 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - return rot_x + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -126,14 +127,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - dtype = q.dtype - q = q.float() - k = k.float() cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) + return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -224,7 +222,7 @@ def forward( if position_embeddings is not None: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/lightglue/modular_lightglue.py b/src/transformers/models/lightglue/modular_lightglue.py index f61e86a67e0d..5bd7b30ee324 100644 --- a/src/transformers/models/lightglue/modular_lightglue.py +++ b/src/transformers/models/lightglue/modular_lightglue.py @@ -29,7 +29,6 @@ from ..auto import CONFIG_MAPPING, AutoConfig from ..auto.modeling_auto import AutoModelForKeypointDetection from ..clip.modeling_clip import CLIPMLP -from ..cohere.modeling_cohere import apply_rotary_pos_emb from ..llama.modeling_llama import LlamaAttention, eager_attention_forward from ..superglue.image_processing_superglue import ( SuperGlueImageProcessor, @@ -284,7 +283,7 @@ def forward( if position_embeddings is not None: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/longcat_flash/modeling_longcat_flash.py b/src/transformers/models/longcat_flash/modeling_longcat_flash.py index 4135bce33d83..8c638168c953 100644 --- a/src/transformers/models/longcat_flash/modeling_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modeling_longcat_flash.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -251,6 +251,34 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -376,6 +404,7 @@ def __init__(self, config, layer_idx: int): config.hidden_size, bias=config.attention_bias, ) + self.rotary_fn = apply_rotary_pos_emb self.scaling = self.qk_head_dim ** (-0.5) if self.config.rope_parameters.get("rope_type", "default") != "default": diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 8b8f8c9adb3a..a4f1a5e7f22c 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -427,7 +427,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/ministral/modeling_ministral.py b/src/transformers/models/ministral/modeling_ministral.py index ccf6c9abcda5..75e200ddf5d6 100644 --- a/src/transformers/models/ministral/modeling_ministral.py +++ b/src/transformers/models/ministral/modeling_ministral.py @@ -158,7 +158,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 43dc5c18a90b..c97fd5b6880a 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -156,7 +156,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff 
--git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 709ff855c399..215499d577a9 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -50,6 +50,7 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -68,7 +69,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 95b236dadce6..c693e6f4d8f9 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -326,7 +326,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index 72e26db9bd1c..2379222e696a 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -254,6 +254,7 @@ def __init__(self, config: MLCDVisionConfig): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) self.num_key_value_groups = config.num_key_value_groups + self.rotary_fn = apply_rotary_pos_emb_vision def forward( self, @@ -273,7 +274,7 @@ def forward( # Apply positional embeddings cos = position_embeddings[0].unsqueeze(0).float() sin = position_embeddings[1].unsqueeze(0).float() - query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) # Each of shape: [batch_size, num_heads, seq_length, head_dim] query_states = query_states.permute(0, 2, 1, 3).contiguous() diff --git a/src/transformers/models/mlcd/modular_mlcd.py b/src/transformers/models/mlcd/modular_mlcd.py index 4cfc743948ef..f9b3b5a73d77 100644 --- a/src/transformers/models/mlcd/modular_mlcd.py +++ b/src/transformers/models/mlcd/modular_mlcd.py @@ -201,6 +201,7 @@ def __init__(self, config: MLCDVisionConfig): super().__init__(config) self.num_key_value_groups = config.num_key_value_groups self.is_causal = False + self.rotary_fn = apply_rotary_pos_emb_vision def forward( self, @@ -219,7 +220,7 @@ def forward( # Apply positional embeddings cos = position_embeddings[0].unsqueeze(0).float() sin = position_embeddings[1].unsqueeze(0).float() - query_states, key_states = apply_rotary_pos_emb_vision(query_states, 
key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) # Each of shape: [batch_size, num_heads, seq_length, head_dim] query_states = query_states.permute(0, 2, 1, 3).contiguous() diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 4869da818e01..42a744f4ae28 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -377,7 +377,7 @@ def eager_attention_forward( cos, sin = position_embeddings query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] - query, key = apply_rotary_pos_emb(query, key, cos, sin) + query, key = module.rotary_fn(query, key, cos, sin) scale = module.head_dim**-0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale @@ -457,7 +457,7 @@ def sdpa_attention_forward( cos, sin = position_embeddings query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] - query, key = apply_rotary_pos_emb(query, key, cos, sin) + query, key = module.rotary_fn(query, key, cos, sin) if local_attention != (-1, -1): attention_mask = sliding_window_mask @@ -532,6 +532,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 8fc375671583..2cde78b24962 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -556,7 +556,7 @@ def eager_attention_forward( cos, sin = position_embeddings query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] - query, key = apply_rotary_pos_emb(query, key, cos, sin) + query, key = module.rotary_fn(query, key, cos, sin) scale = module.head_dim**-0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale @@ -636,7 +636,7 @@ def sdpa_attention_forward( cos, sin = position_embeddings query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] - query, key = apply_rotary_pos_emb(query, key, cos, sin) + query, key = module.rotary_fn(query, key, cos, sin) if local_attention != (-1, -1): attention_mask = sliding_window_mask @@ -711,6 +711,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index 5059e623c9dd..026de89b7202 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -274,6 +274,7 @@ def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = N self.out_drop = 
nn.Dropout(config.attention_dropout) self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -292,7 +293,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py index b39cb440c3e9..a21ca2f61ccb 100644 --- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py @@ -311,6 +311,7 @@ def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = N self.out_drop = nn.Dropout(config.attention_dropout) self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -329,7 +330,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 87893284ec4e..47e178430542 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -323,7 +323,7 @@ def forward( if not is_cross_attention: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 5afbdb2db363..9ea67f52b737 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -274,6 +274,7 @@ def __init__( self.head_dim_padding = target_head_dim - self.head_dim else: self.head_dim_padding = 0 + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -324,7 +325,7 @@ def forward( if not is_cross_attention: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 1dcaa6247155..b56add044f3a 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -266,7 +266,7 @@ def forward( value_states = 
value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 4f2796837cb5..4c6b2f7eeb15 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -114,6 +114,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class OlmoAttention(LlamaAttention): + def __init__(self, config: OlmoConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.rotary_fn = apply_rotary_pos_emb + def forward( self, hidden_states: torch.Tensor, @@ -141,7 +145,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index aa6296f344f1..c2e4071c5f3d 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -256,7 +256,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 152a78c4e462..2b989e0c3841 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -211,6 +211,7 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx=layer_idx) self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -234,7 +235,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmo3/modeling_olmo3.py b/src/transformers/models/olmo3/modeling_olmo3.py index 003e691e9c82..cc5369cceb51 100644 --- a/src/transformers/models/olmo3/modeling_olmo3.py +++ b/src/transformers/models/olmo3/modeling_olmo3.py @@ -189,7 +189,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, 
key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmo3/modular_olmo3.py b/src/transformers/models/olmo3/modular_olmo3.py index 265ef3bae967..d022ac3c2c2f 100644 --- a/src/transformers/models/olmo3/modular_olmo3.py +++ b/src/transformers/models/olmo3/modular_olmo3.py @@ -214,6 +214,7 @@ def __init__(self, config: Olmo3Config, layer_idx: int): assert config.layer_types is not None self.attention_type = config.layer_types[layer_idx] self.sliding_window = config.sliding_window if self.attention_type == "sliding_attention" else None + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -236,7 +237,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index d04cd421d441..6d929138f613 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -270,7 +270,7 @@ def forward( key_states = key_states.view(*hidden_shape).transpose(1, 2) value_states = value_states.view(*hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmoe/modular_olmoe.py b/src/transformers/models/olmoe/modular_olmoe.py index eef444e6f24a..c6a8eb1fe7d7 100644 --- a/src/transformers/models/olmoe/modular_olmoe.py +++ b/src/transformers/models/olmoe/modular_olmoe.py @@ -63,6 +63,7 @@ def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None): self.k_norm = OlmoeRMSNorm( (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -89,7 +90,7 @@ def forward( key_states = key_states.view(*hidden_shape).transpose(1, 2) value_states = value_states.view(*hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index eec8794a3d65..ac7638e8741e 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -231,7 +231,7 @@ def forward( key_states[..., self.rotary_ndims :], ) # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query_rot, key_rot = self.rotary_fn(query_rot, key_rot, cos, sin) # [batch_size, seq_length, num_heads, head_dim] query_states = torch.cat((query_rot, query_pass), 
dim=-1) diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 3ecc9ba9d4f7..8f10a6f4af3e 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -84,6 +84,7 @@ def __init__(self, config: PhiConfig, layer_idx: int): self.k_layernorm = nn.LayerNorm( config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -117,7 +118,7 @@ def forward( key_states[..., self.rotary_ndims :], ) # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query_rot, key_rot = self.rotary_fn(query_rot, key_rot, cos, sin) # [batch_size, seq_length, num_heads, head_dim] query_states = torch.cat((query_rot, query_pass), dim=-1) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 29b3d2847ed1..f7ec566ef491 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -226,6 +226,7 @@ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -250,7 +251,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/phi3/modular_phi3.py b/src/transformers/models/phi3/modular_phi3.py index 5b0d5f76d69c..dee9daa0135c 100644 --- a/src/transformers/models/phi3/modular_phi3.py +++ b/src/transformers/models/phi3/modular_phi3.py @@ -118,6 +118,7 @@ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -142,7 +143,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index eab15068d252..070d47efdcff 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1259,6 +1259,7 @@ def __init__(self, config: Phi4MultimodalConfig, layer_idx: Optional[int] = None op_size = 
config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -1283,7 +1284,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 060683cfb972..28316fc97d33 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -222,7 +222,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index b06d3182b273..efb99a6381b2 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -58,6 +58,7 @@ def __init__(self, config: Qwen2Config, layer_idx: int): self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -76,7 +77,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 0826873a8f98..f4503eaa527f 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -908,6 +908,7 @@ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None: self.config = config self.attention_dropout = 0.0 self.is_causal = False + self.rotary_fn = apply_rotary_pos_emb_vision def forward( self, @@ -920,8 +921,8 @@ def forward( query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) - query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) - key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + query_states = 
self.rotary_fn(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + key_states = self.rotary_fn(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) @@ -3033,6 +3034,7 @@ def __init__(self, config: Qwen2_5OmniDiTConfig): self.to_v = nn.Linear(config.hidden_size, self.inner_dim) self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)]) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -3057,7 +3059,7 @@ def forward( # apply rotary position embedding # Due to training process, only first head is applied with RoPE, will be fixed at next release cos, sin = position_embeddings - query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) + query[:, :1], key[:, :1] = self.rotary_fn(query[:, :1], key[:, :1], cos, sin) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_weights, _ = attention_interface( diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 685c6b5c86f0..b0f0b96f62da 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1840,6 +1840,7 @@ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None: self.config = config self.attention_dropout = 0.0 self.is_causal = False + self.rotary_fn = apply_rotary_pos_emb_vision def forward( self, @@ -1852,8 +1853,8 @@ def forward( query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) - query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) - key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + query_states = self.rotary_fn(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + key_states = self.rotary_fn(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) @@ -3207,6 +3208,7 @@ def __init__(self, config: Qwen2_5OmniDiTConfig): self.to_v = nn.Linear(config.hidden_size, self.inner_dim) self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)]) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -3231,7 +3233,7 @@ def forward( # apply rotary position embedding # Due to training process, only first head is applied with RoPE, will be fixed at next release cos, sin = position_embeddings - query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) + query[:, :1], key[:, :1] = self.rotary_fn(query[:, :1], key[:, :1], cos, sin) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_weights, _ = attention_interface( diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 9b6c41662929..65d5d04090aa 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -269,7 +269,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = 
apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index 2f113785a94e..c4b9ee5a7ed0 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -36,7 +36,6 @@ Qwen2ForTokenClassification, Qwen2RMSNorm, Qwen2RotaryEmbedding, - apply_rotary_pos_emb, eager_attention_forward, ) from .configuration_qwen3 import Qwen3Config @@ -84,7 +83,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index caf7e26a39fe..0d09deb61ef2 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -168,7 +168,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 27581014e9f7..f757ca3166c2 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -399,7 +399,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 7deedb9c868b..40b30e929193 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -227,6 +227,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias ) del self.sliding_window + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -250,7 +251,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git 
a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 73709df5d0bc..7883e6afd4b9 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1494,7 +1494,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -2374,7 +2374,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -3402,7 +3402,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 9922ccb09bb6..7b697f5639f7 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -462,7 +462,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 60253ce21551..4bc849dfb0f0 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -405,6 +405,7 @@ class Qwen3VLTextAttention(Qwen3Attention): def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): super().__init__(config, layer_idx) del self.sliding_window + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -423,7 +424,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 28e2c85f156c..0b60dc4ba3f3 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ 
b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -277,7 +277,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/seed_oss/modeling_seed_oss.py b/src/transformers/models/seed_oss/modeling_seed_oss.py index 877d5eaa4fc0..92597e7e91ca 100644 --- a/src/transformers/models/seed_oss/modeling_seed_oss.py +++ b/src/transformers/models/seed_oss/modeling_seed_oss.py @@ -183,6 +183,7 @@ def __init__(self, config: SeedOssConfig, layer_idx: int): ) self.residual_dropout = config.residual_dropout + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -201,7 +202,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/seed_oss/modular_seed_oss.py b/src/transformers/models/seed_oss/modular_seed_oss.py index 96cb4d428214..7f5a218e74dc 100644 --- a/src/transformers/models/seed_oss/modular_seed_oss.py +++ b/src/transformers/models/seed_oss/modular_seed_oss.py @@ -94,6 +94,7 @@ def __init__(self, config: SeedOssConfig, layer_idx: int): ) self.residual_dropout = config.residual_dropout + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -112,7 +113,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index c59906f63692..8819b361f846 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -236,7 +236,7 @@ def forward( if self.use_rope: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} diff --git a/src/transformers/models/smollm3/modular_smollm3.py b/src/transformers/models/smollm3/modular_smollm3.py index 995c15f10de0..d0f6d08c6279 100644 --- a/src/transformers/models/smollm3/modular_smollm3.py +++ b/src/transformers/models/smollm3/modular_smollm3.py @@ -238,6 +238,7 @@ def __init__(self, config: SmolLM3Config, layer_idx: int): if config.use_sliding_window and config.layer_types[layer_idx] == "sliding_attention" else None ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -257,7 +258,7 @@ def forward( if self.use_rope: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 
+ query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index c013e97f2169..3510365ed95e 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -177,7 +177,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 2c0a27f81bdf..e6e048e18016 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -76,6 +76,7 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -94,7 +95,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index 5c0410655ee4..437f172936df 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -286,7 +286,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/vaultgemma/modeling_vaultgemma.py b/src/transformers/models/vaultgemma/modeling_vaultgemma.py index ad93f819fa0e..9a7c174a34a8 100644 --- a/src/transformers/models/vaultgemma/modeling_vaultgemma.py +++ b/src/transformers/models/vaultgemma/modeling_vaultgemma.py @@ -210,7 +210,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git 
a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 8442a4a7c175..61e835f11b3f 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -200,6 +200,7 @@ def __init__(self, config): self.num_key_value_groups = 1 self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout + self.rotary_fn = apply_rotary_pos_emb_vision def forward( self, @@ -223,7 +224,7 @@ def forward( value_states = self.v_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 0bd05ec9acf8..fbc4d9259939 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -278,6 +278,7 @@ def __init__(self, config): self.attention_dropout = config.attention_dropout del self.scale del self.dropout + self.rotary_fn = apply_rotary_pos_emb_vision def forward( self, @@ -301,7 +302,7 @@ def forward( value_states = self.v_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 9ade9d1dfc8d..ef31502c40ec 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -416,6 +416,7 @@ def __init__( self.linear_v_adapter_list.append(linear_v_adapter) self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -445,7 +446,7 @@ def forward( if self.config.use_mem_rope: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) if past_key_values is not None: key_states, value_states = past_key_values.update(key_states, value_states, layer_idx) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index b4761a50dd42..c744bb1e3954 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -224,6 +224,7 @@ def __init__( self.linear_v_adapter_list.append(linear_v_adapter) self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} + self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -253,7 +254,7 @@ def forward( if self.config.use_mem_rope: cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) 
if past_key_values is not None: key_states, value_states = past_key_values.update(key_states, value_states, layer_idx) From af25ce0eb0cfb70418870ef0a0ca536d0a62f4b5 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 28 Nov 2025 15:57:05 +0000 Subject: [PATCH 31/32] Revert "update modular part" This reverts commit b8b68c7caf466257b88ef910efa18d0812c33597. --- .../models/apertus/modeling_apertus.py | 2 +- .../models/apertus/modular_apertus.py | 3 +- .../models/bamba/modeling_bamba.py | 2 +- .../models/bamba/modular_bamba.py | 4 +- .../models/bitnet/modeling_bitnet.py | 2 +- .../models/bitnet/modular_bitnet.py | 3 +- .../models/cohere/modeling_cohere.py | 2 +- .../models/cohere/modular_cohere.py | 3 +- .../models/cohere2/modeling_cohere2.py | 3 +- .../models/cohere2/modular_cohere2.py | 3 +- src/transformers/models/cwm/modeling_cwm.py | 2 +- src/transformers/models/dbrx/modeling_dbrx.py | 3 +- src/transformers/models/dbrx/modular_dbrx.py | 3 +- .../deepseek_v3/modeling_deepseek_v3.py | 3 +- .../models/deepseek_v3/modular_deepseek_v3.py | 3 +- .../models/diffllama/modeling_diffllama.py | 7 ++- .../models/diffllama/modular_diffllama.py | 7 ++- .../models/dinov3_vit/modeling_dinov3_vit.py | 3 +- .../models/dinov3_vit/modular_dinov3_vit.py | 3 +- src/transformers/models/doge/modeling_doge.py | 3 +- src/transformers/models/doge/modular_doge.py | 3 +- .../models/dots1/modeling_dots1.py | 2 +- .../models/ernie4_5/modular_ernie4_5.py | 1 - .../models/exaone4/modeling_exaone4.py | 3 +- .../models/exaone4/modular_exaone4.py | 3 +- .../models/falcon_h1/modeling_falcon_h1.py | 2 +- .../models/falcon_h1/modular_falcon_h1.py | 3 +- .../models/flex_olmo/modeling_flex_olmo.py | 2 +- .../models/gemma2/modeling_gemma2.py | 2 +- .../models/gemma2/modular_gemma2.py | 3 +- .../models/gemma3/modeling_gemma3.py | 2 +- .../models/gemma3/modular_gemma3.py | 3 +- .../models/gemma3n/modeling_gemma3n.py | 6 +-- .../models/gemma3n/modular_gemma3n.py | 5 +- src/transformers/models/glm/modular_glm.py | 1 - .../models/glm4_moe/modeling_glm4_moe.py | 47 ++++++++++++++++++- .../models/gpt_neox/modeling_gpt_neox.py | 3 +- .../models/gpt_neox/modular_gpt_neox.py | 3 +- .../models/gpt_oss/modeling_gpt_oss.py | 2 +- .../models/gpt_oss/modular_gpt_oss.py | 3 +- .../models/helium/modular_helium.py | 1 - .../modeling_hunyuan_v1_dense.py | 2 +- .../modular_hunyuan_v1_dense.py | 3 +- .../hunyuan_v1_moe/modeling_hunyuan_v1_moe.py | 2 +- .../hunyuan_v1_moe/modular_hunyuan_v1_moe.py | 3 +- .../models/jetmoe/modeling_jetmoe.py | 3 +- .../models/jetmoe/modular_jetmoe.py | 3 +- src/transformers/models/lfm2/modeling_lfm2.py | 2 +- src/transformers/models/lfm2/modular_lfm2.py | 3 +- .../models/lfm2_moe/modeling_lfm2_moe.py | 2 +- .../models/lightglue/modeling_lightglue.py | 18 +++---- .../models/lightglue/modular_lightglue.py | 3 +- .../longcat_flash/modeling_longcat_flash.py | 31 +----------- .../models/minimax/modeling_minimax.py | 2 +- .../models/ministral/modeling_ministral.py | 2 +- .../models/mistral/modeling_mistral.py | 2 +- .../models/mistral/modular_mistral.py | 3 +- .../models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/mlcd/modeling_mlcd.py | 3 +- src/transformers/models/mlcd/modular_mlcd.py | 3 +- .../models/modernbert/modeling_modernbert.py | 5 +- .../models/modernbert/modular_modernbert.py | 5 +- .../modeling_modernbert_decoder.py | 3 +- .../modular_modernbert_decoder.py | 3 +- .../models/moonshine/modeling_moonshine.py | 2 +- .../models/moonshine/modular_moonshine.py | 3 +- 
src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/olmo/modular_olmo.py | 6 +-- .../models/olmo2/modeling_olmo2.py | 2 +- .../models/olmo2/modular_olmo2.py | 3 +- .../models/olmo3/modeling_olmo3.py | 2 +- .../models/olmo3/modular_olmo3.py | 3 +- .../models/olmoe/modeling_olmoe.py | 2 +- .../models/olmoe/modular_olmoe.py | 3 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi/modular_phi.py | 3 +- src/transformers/models/phi3/modeling_phi3.py | 3 +- src/transformers/models/phi3/modular_phi3.py | 3 +- .../modeling_phi4_multimodal.py | 3 +- .../models/qwen2/modeling_qwen2.py | 2 +- .../models/qwen2/modular_qwen2.py | 3 +- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 8 ++-- .../qwen2_5_omni/modular_qwen2_5_omni.py | 8 ++-- .../models/qwen3/modeling_qwen3.py | 2 +- .../models/qwen3/modular_qwen3.py | 3 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 2 +- .../models/qwen3_next/modeling_qwen3_next.py | 2 +- .../models/qwen3_next/modular_qwen3_next.py | 3 +- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 6 +-- .../models/qwen3_vl/modeling_qwen3_vl.py | 2 +- .../models/qwen3_vl/modular_qwen3_vl.py | 3 +- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 2 +- .../models/seed_oss/modeling_seed_oss.py | 3 +- .../models/seed_oss/modular_seed_oss.py | 3 +- .../models/smollm3/modeling_smollm3.py | 2 +- .../models/smollm3/modular_smollm3.py | 3 +- .../models/starcoder2/modeling_starcoder2.py | 2 +- .../models/starcoder2/modular_starcoder2.py | 3 +- .../models/t5gemma/modeling_t5gemma.py | 2 +- .../models/vaultgemma/modeling_vaultgemma.py | 2 +- .../video_llama_3/modeling_video_llama_3.py | 3 +- .../video_llama_3/modular_video_llama_3.py | 3 +- .../models/zamba2/modeling_zamba2.py | 3 +- .../models/zamba2/modular_zamba2.py | 3 +- 104 files changed, 175 insertions(+), 217 deletions(-) diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index 61f7601f7f63..1780b6383cc4 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -261,7 +261,7 @@ def forward( key_states = self.k_norm(key_states) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py index d75a9137d09f..a60daa2f8194 100644 --- a/src/transformers/models/apertus/modular_apertus.py +++ b/src/transformers/models/apertus/modular_apertus.py @@ -34,6 +34,7 @@ LlamaPreTrainedModel, LlamaRMSNorm, LlamaRotaryEmbedding, + apply_rotary_pos_emb, eager_attention_forward, ) from ..nemotron.modeling_nemotron import NemotronMLP @@ -224,7 +225,7 @@ def forward( key_states = self.k_norm(key_states) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index a0aae79d3899..7a5ef3370088 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ 
b/src/transformers/models/bamba/modeling_bamba.py @@ -348,7 +348,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class BambaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, *args, **kwargs): + def __init__(self, config: BambaConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index f31518bd6ed5..d29273940b8a 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -198,9 +198,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class BambaAttention(LlamaAttention): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.rotary_fn = apply_rotary_pos_emb + pass class BambaRMSNormGated(MambaRMSNormGated): diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index eb2807657950..25568ca92365 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -196,7 +196,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/bitnet/modular_bitnet.py b/src/transformers/models/bitnet/modular_bitnet.py index 2ca747807afa..093eb2428395 100644 --- a/src/transformers/models/bitnet/modular_bitnet.py +++ b/src/transformers/models/bitnet/modular_bitnet.py @@ -58,7 +58,6 @@ class BitNetAttention(LlamaAttention): def __init__(self, config: BitNetConfig, layer_idx: int): super().__init__(config, layer_idx) self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -77,7 +76,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 556b007ce320..752b693fe7b8 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -283,7 +283,7 @@ def forward( value_states = value_states.transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index acb44b9dbff2..5147e5638eb2 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ 
b/src/transformers/models/cohere/modular_cohere.py @@ -144,7 +144,6 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None): self.k_norm = CohereLayerNorm( hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps ) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -171,7 +170,7 @@ def forward( value_states = value_states.transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 5e0423c2875d..a9c56cd2491c 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -225,7 +225,6 @@ def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -245,7 +244,7 @@ def forward( cos, sin = position_embeddings if self.sliding_window is not None: - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index afc3f0e8cdbb..9e8bc6b564e4 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -270,7 +270,6 @@ def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -290,7 +289,7 @@ def forward( cos, sin = position_embeddings if self.sliding_window is not None: - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py index 39ae259fe0a7..53cc8c23a8c6 100644 --- a/src/transformers/models/cwm/modeling_cwm.py +++ b/src/transformers/models/cwm/modeling_cwm.py @@ -216,7 +216,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index faee15b60f30..92df09947bc4 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -209,7 +209,6 @@ def __init__( self.hidden_size, self.hidden_size + 2 * 
self.num_key_value_heads * self.head_dim, bias=False ) self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -241,7 +240,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/dbrx/modular_dbrx.py b/src/transformers/models/dbrx/modular_dbrx.py index 7edffa0a70a6..42a9079cb012 100644 --- a/src/transformers/models/dbrx/modular_dbrx.py +++ b/src/transformers/models/dbrx/modular_dbrx.py @@ -77,7 +77,6 @@ def __init__( self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False ) self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -109,7 +108,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 2d52144add58..4c56277c69dd 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -406,7 +406,6 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): config.hidden_size, bias=config.attention_bias, ) - self.rotary_fn = apply_rotary_pos_emb self.scaling = self.qk_head_dim ** (-0.5) if self.config.rope_parameters.get("rope_type", "default") != "default": @@ -448,7 +447,7 @@ def forward( if self.config.rope_interleave: # support using interleaved weights for efficiency q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) else: - q_rot, k_rot = self.rotary_fn(q_rot, k_rot, cos, sin) + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) k_rot = k_rot.expand(*k_pass.shape[:-1], -1) query_states = torch.cat((q_pass, q_rot), dim=-1) diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 6f744389bc35..97c3f0cd425b 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -208,7 +208,6 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): config.hidden_size, bias=config.attention_bias, ) - self.rotary_fn = apply_rotary_pos_emb self.scaling = self.qk_head_dim ** (-0.5) if self.config.rope_parameters.get("rope_type", "default") != "default": @@ -250,7 +249,7 @@ def forward( if self.config.rope_interleave: # support using interleaved weights for efficiency q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) else: - q_rot, k_rot = self.rotary_fn(q_rot, k_rot, cos, sin) + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) k_rot = k_rot.expand(*k_pass.shape[:-1], -1) query_states = torch.cat((q_pass, q_rot), dim=-1) diff --git 
a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 52210d35135a..b23ee2ed948d 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -220,7 +220,6 @@ def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None): self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -245,7 +244,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -330,7 +329,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -455,7 +454,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 13a763f54fdd..904881ed7fbd 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -95,7 +95,6 @@ def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None): self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -120,7 +119,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -205,7 +204,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = 
apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -330,7 +329,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py index ad0cba3f689a..09edeed17543 100644 --- a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py @@ -274,7 +274,6 @@ def __init__(self, config: DINOv3ViTConfig): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias) self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -296,7 +295,7 @@ def forward( value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/dinov3_vit/modular_dinov3_vit.py b/src/transformers/models/dinov3_vit/modular_dinov3_vit.py index 85bf818c1bf9..7ae7d1632053 100644 --- a/src/transformers/models/dinov3_vit/modular_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modular_dinov3_vit.py @@ -231,7 +231,6 @@ def __init__(self, config: DINOv3ViTConfig): self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias) self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -253,7 +252,7 @@ def forward( value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index 303808e50d20..a3a94f88df55 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -289,7 +289,6 @@ def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): ) self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -309,7 +308,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, 
key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/doge/modular_doge.py b/src/transformers/models/doge/modular_doge.py index 629273fe87f2..eacea60cf442 100644 --- a/src/transformers/models/doge/modular_doge.py +++ b/src/transformers/models/doge/modular_doge.py @@ -325,7 +325,6 @@ def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): ) self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -345,7 +344,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 8a1cdfd06fed..9092a3533e43 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -249,7 +249,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/ernie4_5/modular_ernie4_5.py b/src/transformers/models/ernie4_5/modular_ernie4_5.py index 3d41f86c951d..780b07164ec0 100644 --- a/src/transformers/models/ernie4_5/modular_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modular_ernie4_5.py @@ -101,7 +101,6 @@ def __init__(self, config: Ernie4_5Config, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) - self.rotary_fn = apply_rotary_pos_emb class Ernie4_5ForCausalLM(LlamaForCausalLM): diff --git a/src/transformers/models/exaone4/modeling_exaone4.py b/src/transformers/models/exaone4/modeling_exaone4.py index f6e8c7966ce2..2287f9f8e88d 100644 --- a/src/transformers/models/exaone4/modeling_exaone4.py +++ b/src/transformers/models/exaone4/modeling_exaone4.py @@ -231,7 +231,6 @@ def __init__(self, config: Exaone4Config, layer_idx: int): self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -257,7 +256,7 @@ def forward( cos, sin = position_embeddings # We use global NoPE for hybrid attention model if self.sliding_window is None or self.is_sliding: - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = { diff --git a/src/transformers/models/exaone4/modular_exaone4.py 
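Illustrative aside (not part of the patch): the change repeated across the hunks above and below is mechanical. Each attention module used to store the helper on itself as `self.rotary_fn = apply_rotary_pos_emb` and call `self.rotary_fn(...)`; the attribute is removed and the module-level `apply_rotary_pos_emb(...)` is called directly. A minimal sketch of the helper being called, with made-up tensor sizes:

# Minimal sketch only; condensed from the standard HF-style helper, not copied from any single file here.
import torch

def rotate_half(x):
    # split the last dimension in half and swap the halves with a sign flip
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # position_ids is kept only for signature compatibility; it is unused
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

q = torch.randn(1, 2, 4, 8)   # [batch, heads, seq, head_dim], made-up sizes
k = torch.randn(1, 2, 4, 8)
cos = torch.randn(1, 4, 8)    # [batch, seq, head_dim]
sin = torch.randn(1, 4, 8)

# before: q, k = self.rotary_fn(q, k, cos, sin)
# after:  q, k = apply_rotary_pos_emb(q, k, cos, sin)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)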
b/src/transformers/models/exaone4/modular_exaone4.py index 333a966592ba..d6004db0d28c 100644 --- a/src/transformers/models/exaone4/modular_exaone4.py +++ b/src/transformers/models/exaone4/modular_exaone4.py @@ -258,7 +258,6 @@ def __init__(self, config: Exaone4Config, layer_idx: int): self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -284,7 +283,7 @@ def forward( cos, sin = position_embeddings # We use global NoPE for hybrid attention model if self.sliding_window is None or self.is_sliding: - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = { diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index d3effa1d6915..1cf19c500737 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -406,7 +406,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index f191df88dfdb..386266b89bf3 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -205,7 +205,6 @@ class FalconH1Attention(LlamaAttention): def __init__(self, config: FalconH1Config, layer_idx: int): super().__init__(config, layer_idx) self.key_multiplier = config.key_multiplier - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -224,7 +223,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py index b4de3e67fffb..993a3dae1652 100644 --- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py +++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py @@ -267,7 +267,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 956fc3e6f8e0..d6362634b1e3 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -279,7 +279,7 @@ def forward( 
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 5058d39f9335..f33a0c7ba666 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -313,7 +313,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int): self.is_causal = not getattr(config, "use_bidirectional_attention", False) self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -332,7 +331,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index e59116df9f37..253931d4eeb1 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -360,7 +360,7 @@ def forward( key_states = self.k_norm(key_states) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 5c499da4b323..7f9e987035ad 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -450,7 +450,6 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int): self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -472,7 +471,7 @@ def forward( key_states = self.k_norm(key_states) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 59fea8fcbb31..3807311ff674 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1291,7 +1291,7 @@ def forward( cos, sin = position_embeddings query_states = self.q_proj(hidden_states).view(hidden_shape) query_states = self.q_norm(query_states) - query_states = self.rotary_fn(query_states, cos, sin, unsqueeze_dim=2) + 
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
         query_states = query_states.transpose(1, 2)
 
         # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
@@ -1303,7 +1303,7 @@ def forward(
         else:
             key_states = self.k_proj(hidden_states).view(hidden_shape)
             key_states = self.k_norm(key_states)
-            key_states = self.rotary_fn(key_states, cos, sin, unsqueeze_dim=2)
+            key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
             key_states = key_states.transpose(1, 2)
 
         value_states = self.v_proj(hidden_states).view(hidden_shape)
@@ -1497,7 +1497,7 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_values is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py
index a4b7aeca8777..a226824781cf 100644
--- a/src/transformers/models/gemma3n/modular_gemma3n.py
+++ b/src/transformers/models/gemma3n/modular_gemma3n.py
@@ -1720,7 +1720,6 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
         self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
             config.layer_types[layer_idx]
         )
-        self.rotary_fn = apply_rotary_pos_emb
 
     def forward(
         self,
@@ -1737,7 +1736,7 @@ def forward(
         cos, sin = position_embeddings
         query_states = self.q_proj(hidden_states).view(hidden_shape)
         query_states = self.q_norm(query_states)
-        query_states = self.rotary_fn(query_states, cos, sin, unsqueeze_dim=2)
+        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
         query_states = query_states.transpose(1, 2)
 
         # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
@@ -1749,7 +1748,7 @@
         else:
             key_states = self.k_proj(hidden_states).view(hidden_shape)
             key_states = self.k_norm(key_states)
-            key_states = self.rotary_fn(key_states, cos, sin, unsqueeze_dim=2)
+            key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
             key_states = key_states.transpose(1, 2)
 
         value_states = self.v_proj(hidden_states).view(hidden_shape)
diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py
index 76532939ec7f..059cb296c972 100644
--- a/src/transformers/models/glm/modular_glm.py
+++ b/src/transformers/models/glm/modular_glm.py
@@ -126,7 +126,6 @@ class GlmAttention(LlamaAttention):
     def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
         super().__init__(config, layer_idx)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb
 
 
 class GlmForCausalLM(LlamaForCausalLM):
diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py
index a914f44051c7..84e6dd3bd77d 100644
--- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py
+++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py
@@ -148,6 +148,51 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+
+    # Keep half or full tensor for later concatenation
+    rotary_dim = cos.shape[-1]
+    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+    # Apply rotary embeddings on the first half or full tensor
+    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+    # Concatenate back to full shape
+    q_embed = torch.cat([q_embed, q_pass], dim=-1)
+    k_embed = torch.cat([k_embed, k_pass], dim=-1)
+    return q_embed, k_embed
+
+
 class Glm4MoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -202,7 +247,7 @@ def forward(
         value_states = value_states.transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_values is not None:
             # sin and cos are specific to RoPE models; position_ids needed for the static cache
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 4755872ebcab..fc7d6fd40a80 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -201,7 +201,6 @@ def __init__(self, config, layer_idx=None):
 
         self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
-        self.rotary_fn = apply_rotary_pos_emb
 
     def forward(
         self,
@@ -219,7 +218,7 @@ def forward(
         query_states, key_states, value_states = qkv.chunk(3, dim=-1)
 
         cos, sin = position_embeddings
-        query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         # Cache QKV values
         if layer_past is not None:
diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py
index bc7d222fda5b..c267753db350 100644
--- a/src/transformers/models/gpt_neox/modular_gpt_neox.py
+++
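Illustrative aside (not part of the patch): the `apply_rotary_pos_emb` added to Glm4Moe above rotates only the leading `rotary_dim = cos.shape[-1]` channels and passes the remaining channels through unchanged. A minimal check of that behavior with made-up shapes; the helpers below are condensed copies of the functions in the hunk above:

# Minimal sketch only; shapes are made up.
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
    return torch.cat([q_rot, q_pass], dim=-1), torch.cat([k_rot, k_pass], dim=-1)

q = torch.randn(1, 2, 4, 8)   # [batch, heads, seq, head_dim], made-up sizes
k = torch.randn(1, 2, 4, 8)
cos = torch.randn(1, 4, 4)    # rotary_dim = 4 < head_dim = 8
sin = torch.randn(1, 4, 4)

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
assert torch.equal(q_embed[..., 4:], q[..., 4:])  # non-rotary channels pass through untouched
assert torch.equal(k_embed[..., 4:], k[..., 4:])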
b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -153,7 +153,6 @@ def __init__(self, config, layer_idx=None): self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias) self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -171,7 +170,7 @@ def forward( query_states, key_states, value_states = qkv.chunk(3, dim=-1) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Cache QKV values if layer_past is not None: diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 745b06ae56a2..5e1173d823d0 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -347,7 +347,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index a0767155415a..57acfea8df64 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -254,7 +254,6 @@ def __init__(self, config: GptOssConfig, layer_idx: int): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -274,7 +273,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} diff --git a/src/transformers/models/helium/modular_helium.py b/src/transformers/models/helium/modular_helium.py index b99477992385..79d995720e30 100644 --- a/src/transformers/models/helium/modular_helium.py +++ b/src/transformers/models/helium/modular_helium.py @@ -99,7 +99,6 @@ def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) self.scaling = 1 / math.sqrt(self.head_dim) - self.rotary_fn = apply_rotary_pos_emb class HeliumDecoderLayer(LlamaDecoderLayer): diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py index 66709468c089..28899e306f4c 100644 --- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py @@ -199,7 +199,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = 
apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states = self.query_layernorm(query_states) key_states = self.key_layernorm(key_states) diff --git a/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py index 2462ab746e4a..d41b5f236759 100644 --- a/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py @@ -66,7 +66,6 @@ def __init__(self, config: HunYuanDenseV1Config, layer_idx: int): super().__init__(config, layer_idx) self.query_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.key_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -85,7 +84,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states = self.query_layernorm(query_states) key_states = self.key_layernorm(key_states) diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index 7986b4a92daa..dda4366f0d4d 100644 --- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -198,7 +198,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states = self.query_layernorm(query_states) key_states = self.key_layernorm(key_states) diff --git a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py index 2abf15e24649..f94c4ed4c2c7 100644 --- a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py @@ -36,6 +36,7 @@ LlamaModel, LlamaPreTrainedModel, LlamaRMSNorm, + apply_rotary_pos_emb, eager_attention_forward, ) from ..mixtral.modeling_mixtral import MixtralExperts @@ -80,7 +81,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states = self.query_layernorm(query_states) key_states = self.key_layernorm(key_states) diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index d21f4159e285..1ef99fc46d7f 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -469,7 +469,6 @@ def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None): self.experts = JetMoeMoA(config) self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -491,7 +490,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - 
query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/jetmoe/modular_jetmoe.py b/src/transformers/models/jetmoe/modular_jetmoe.py index 94dd2d508300..db8c3e1059c0 100644 --- a/src/transformers/models/jetmoe/modular_jetmoe.py +++ b/src/transformers/models/jetmoe/modular_jetmoe.py @@ -324,7 +324,6 @@ def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None): self.experts = JetMoeMoA(config) self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -346,7 +345,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 2d6c726672cd..aa66a2e2d80c 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -395,7 +395,7 @@ def forward( value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index 3a12d12b004b..0075280e6ddb 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -33,6 +33,7 @@ LlamaModel, LlamaPreTrainedModel, LlamaRMSNorm, + apply_rotary_pos_emb, eager_attention_forward, ) from .configuration_lfm2 import Lfm2Config @@ -243,7 +244,7 @@ def forward( value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index e4d54cd78abe..55dcde8b07e9 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -466,7 +466,7 @@ def forward( value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/lightglue/modeling_lightglue.py b/src/transformers/models/lightglue/modeling_lightglue.py index 
83adcea63a6d..e684b11130d8 100644 --- a/src/transformers/models/lightglue/modeling_lightglue.py +++ b/src/transformers/models/lightglue/modeling_lightglue.py @@ -27,7 +27,6 @@ from torch.nn.utils.rnn import pad_sequence from ...activations import ACT2FN -from ...integrations import use_kernel_func_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -100,13 +99,13 @@ def forward( def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x -@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -127,11 +126,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ + dtype = q.dtype + q = q.float() + k = k.float() cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed + return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -222,7 +224,7 @@ def forward( if position_embeddings is not None: cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/lightglue/modular_lightglue.py b/src/transformers/models/lightglue/modular_lightglue.py index 5bd7b30ee324..f61e86a67e0d 100644 --- a/src/transformers/models/lightglue/modular_lightglue.py +++ b/src/transformers/models/lightglue/modular_lightglue.py @@ -29,6 +29,7 @@ from ..auto import CONFIG_MAPPING, AutoConfig from ..auto.modeling_auto import AutoModelForKeypointDetection from ..clip.modeling_clip import CLIPMLP +from ..cohere.modeling_cohere import apply_rotary_pos_emb from ..llama.modeling_llama import LlamaAttention, eager_attention_forward from ..superglue.image_processing_superglue import ( SuperGlueImageProcessor, @@ -283,7 +284,7 @@ def forward( if position_embeddings is not None: cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/longcat_flash/modeling_longcat_flash.py b/src/transformers/models/longcat_flash/modeling_longcat_flash.py index 8c638168c953..4135bce33d83 100644 --- a/src/transformers/models/longcat_flash/modeling_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modeling_longcat_flash.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import 
use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -251,34 +251,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -@use_kernel_func_from_hub("rotary_pos_emb") -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
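Illustrative aside (not part of the patch): LightGlue's rewritten `rotate_half` pairs adjacent even/odd channels instead of splitting the last dimension in half, and its `apply_rotary_pos_emb` now upcasts to float32 before rotating and casts back afterwards. A minimal sketch contrasting the two channel pairings, with a made-up input:

# Minimal sketch only; the input values are made up.
import torch

def rotate_half_split(x):
    # Llama-style: split the last dim in half
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rotate_half_interleaved(x):
    # LightGlue-style: pair even/odd channels
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack([-x2, x1], dim=-1).flatten(-2)

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half_split(x))        # tensor([-3., -4.,  1.,  2.])
print(rotate_half_interleaved(x))  # tensor([-2.,  1., -4.,  3.])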
The hidden states go from (batch, @@ -404,7 +376,6 @@ def __init__(self, config, layer_idx: int): config.hidden_size, bias=config.attention_bias, ) - self.rotary_fn = apply_rotary_pos_emb self.scaling = self.qk_head_dim ** (-0.5) if self.config.rope_parameters.get("rope_type", "default") != "default": diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index a4f1a5e7f22c..8b8f8c9adb3a 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -427,7 +427,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/ministral/modeling_ministral.py b/src/transformers/models/ministral/modeling_ministral.py index 75e200ddf5d6..ccf6c9abcda5 100644 --- a/src/transformers/models/ministral/modeling_ministral.py +++ b/src/transformers/models/ministral/modeling_ministral.py @@ -158,7 +158,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index c97fd5b6880a..43dc5c18a90b 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -156,7 +156,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 215499d577a9..709ff855c399 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -50,7 +50,6 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -69,7 +68,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff 
--git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index c693e6f4d8f9..95b236dadce6 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -326,7 +326,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index 2379222e696a..72e26db9bd1c 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -254,7 +254,6 @@ def __init__(self, config: MLCDVisionConfig): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) self.num_key_value_groups = config.num_key_value_groups - self.rotary_fn = apply_rotary_pos_emb_vision def forward( self, @@ -274,7 +273,7 @@ def forward( # Apply positional embeddings cos = position_embeddings[0].unsqueeze(0).float() sin = position_embeddings[1].unsqueeze(0).float() - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) # Each of shape: [batch_size, num_heads, seq_length, head_dim] query_states = query_states.permute(0, 2, 1, 3).contiguous() diff --git a/src/transformers/models/mlcd/modular_mlcd.py b/src/transformers/models/mlcd/modular_mlcd.py index f9b3b5a73d77..4cfc743948ef 100644 --- a/src/transformers/models/mlcd/modular_mlcd.py +++ b/src/transformers/models/mlcd/modular_mlcd.py @@ -201,7 +201,6 @@ def __init__(self, config: MLCDVisionConfig): super().__init__(config) self.num_key_value_groups = config.num_key_value_groups self.is_causal = False - self.rotary_fn = apply_rotary_pos_emb_vision def forward( self, @@ -220,7 +219,7 @@ def forward( # Apply positional embeddings cos = position_embeddings[0].unsqueeze(0).float() sin = position_embeddings[1].unsqueeze(0).float() - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) # Each of shape: [batch_size, num_heads, seq_length, head_dim] query_states = query_states.permute(0, 2, 1, 3).contiguous() diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 42a744f4ae28..4869da818e01 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -377,7 +377,7 @@ def eager_attention_forward( cos, sin = position_embeddings query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] - query, key = module.rotary_fn(query, key, cos, sin) + query, key = apply_rotary_pos_emb(query, key, cos, sin) scale = module.head_dim**-0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale @@ -457,7 +457,7 @@ def sdpa_attention_forward( cos, sin = position_embeddings query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] - query, 
key = module.rotary_fn(query, key, cos, sin) + query, key = apply_rotary_pos_emb(query, key, cos, sin) if local_attention != (-1, -1): attention_mask = sliding_window_mask @@ -532,7 +532,6 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() - self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 2cde78b24962..8fc375671583 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -556,7 +556,7 @@ def eager_attention_forward( cos, sin = position_embeddings query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] - query, key = module.rotary_fn(query, key, cos, sin) + query, key = apply_rotary_pos_emb(query, key, cos, sin) scale = module.head_dim**-0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale @@ -636,7 +636,7 @@ def sdpa_attention_forward( cos, sin = position_embeddings query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] - query, key = module.rotary_fn(query, key, cos, sin) + query, key = apply_rotary_pos_emb(query, key, cos, sin) if local_attention != (-1, -1): attention_mask = sliding_window_mask @@ -711,7 +711,6 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() - self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index 026de89b7202..5059e623c9dd 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -274,7 +274,6 @@ def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = N self.out_drop = nn.Dropout(config.attention_dropout) self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -293,7 +292,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py index a21ca2f61ccb..b39cb440c3e9 100644 --- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py @@ -311,7 +311,6 @@ def __init__(self, config: ModernBertDecoderConfig, layer_idx: Optional[int] = N self.out_drop = 
nn.Dropout(config.attention_dropout) self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -330,7 +329,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 47e178430542..87893284ec4e 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -323,7 +323,7 @@ def forward( if not is_cross_attention: cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 9ea67f52b737..5afbdb2db363 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -274,7 +274,6 @@ def __init__( self.head_dim_padding = target_head_dim - self.head_dim else: self.head_dim_padding = 0 - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -325,7 +324,7 @@ def forward( if not is_cross_attention: cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index b56add044f3a..1dcaa6247155 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -266,7 +266,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 4c6b2f7eeb15..4f2796837cb5 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -114,10 +114,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class OlmoAttention(LlamaAttention): - def __init__(self, config: OlmoConfig, layer_idx: int): - super().__init__(config, layer_idx) - self.rotary_fn = apply_rotary_pos_emb - def forward( self, hidden_states: torch.Tensor, @@ -145,7 +141,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) 
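Illustrative aside (not part of the patch): in the ModernBert attention functions above, RoPE is applied right after the fused QKV projection is unbound into per-head query/key/value tensors. A minimal shape check of the `transpose(3, 1).unbind(dim=2)` step, with made-up sizes:

# Minimal sketch only; sizes are made up.
import torch

batch, seq, heads, head_dim = 2, 5, 4, 8
qkv = torch.randn(batch, seq, 3, heads, head_dim)  # fused QKV projection output

# swap dims 1 and 3, then split the size-3 axis into query/key/value
query, key, value = qkv.transpose(3, 1).unbind(dim=2)
print(query.shape)  # torch.Size([2, 4, 5, 8]) -> [batch_size, heads, seq_len, head_dim]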
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index c2e4071c5f3d..aa6296f344f1 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -256,7 +256,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 2b989e0c3841..152a78c4e462 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -211,7 +211,6 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx=layer_idx) self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -235,7 +234,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmo3/modeling_olmo3.py b/src/transformers/models/olmo3/modeling_olmo3.py index cc5369cceb51..003e691e9c82 100644 --- a/src/transformers/models/olmo3/modeling_olmo3.py +++ b/src/transformers/models/olmo3/modeling_olmo3.py @@ -189,7 +189,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmo3/modular_olmo3.py b/src/transformers/models/olmo3/modular_olmo3.py index d022ac3c2c2f..265ef3bae967 100644 --- a/src/transformers/models/olmo3/modular_olmo3.py +++ b/src/transformers/models/olmo3/modular_olmo3.py @@ -214,7 +214,6 @@ def __init__(self, config: Olmo3Config, layer_idx: int): assert config.layer_types is not None self.attention_type = config.layer_types[layer_idx] self.sliding_window = config.sliding_window if self.attention_type == "sliding_attention" else None - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -237,7 +236,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; 
cache_position needed for the static cache diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 6d929138f613..d04cd421d441 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -270,7 +270,7 @@ def forward( key_states = key_states.view(*hidden_shape).transpose(1, 2) value_states = value_states.view(*hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/olmoe/modular_olmoe.py b/src/transformers/models/olmoe/modular_olmoe.py index c6a8eb1fe7d7..eef444e6f24a 100644 --- a/src/transformers/models/olmoe/modular_olmoe.py +++ b/src/transformers/models/olmoe/modular_olmoe.py @@ -63,7 +63,6 @@ def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None): self.k_norm = OlmoeRMSNorm( (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps ) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -90,7 +89,7 @@ def forward( key_states = key_states.view(*hidden_shape).transpose(1, 2) value_states = value_states.view(*hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index ac7638e8741e..eec8794a3d65 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -231,7 +231,7 @@ def forward( key_states[..., self.rotary_ndims :], ) # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] - query_rot, key_rot = self.rotary_fn(query_rot, key_rot, cos, sin) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) # [batch_size, seq_length, num_heads, head_dim] query_states = torch.cat((query_rot, query_pass), dim=-1) diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 8f10a6f4af3e..3ecc9ba9d4f7 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -84,7 +84,6 @@ def __init__(self, config: PhiConfig, layer_idx: int): self.k_layernorm = nn.LayerNorm( config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True ) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -118,7 +117,7 @@ def forward( key_states[..., self.rotary_ndims :], ) # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] - query_rot, key_rot = self.rotary_fn(query_rot, key_rot, cos, sin) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) # [batch_size, seq_length, num_heads, head_dim] query_states = torch.cat((query_rot, query_pass), dim=-1) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index f7ec566ef491..29b3d2847ed1 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ 
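Illustrative aside (not part of the patch): Phi applies RoPE only to the first `self.rotary_ndims` channels, splitting on the caller side and concatenating the untouched remainder back afterwards (unlike the Glm4Moe helper, which does the split internally). A minimal sketch with made-up sizes; `partial_rotary_factor` here is a stand-in for the config value:

# Minimal sketch only; sizes and factors are made up.
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, partial_rotary_factor = 8, 0.5
rotary_ndims = int(head_dim * partial_rotary_factor)

q = torch.randn(1, 4, 6, head_dim)                  # [batch, heads, seq, head_dim]
cos = torch.randn(1, 6, rotary_ndims).unsqueeze(1)  # broadcast over heads
sin = torch.randn(1, 6, rotary_ndims).unsqueeze(1)

# split, rotate only the first rotary_ndims channels, then glue back together
q_rot, q_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]
q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
q = torch.cat((q_rot, q_pass), dim=-1)              # back to full head_dim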
b/src/transformers/models/phi3/modeling_phi3.py @@ -226,7 +226,6 @@ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -251,7 +250,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/phi3/modular_phi3.py b/src/transformers/models/phi3/modular_phi3.py index dee9daa0135c..5b0d5f76d69c 100644 --- a/src/transformers/models/phi3/modular_phi3.py +++ b/src/transformers/models/phi3/modular_phi3.py @@ -118,7 +118,6 @@ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -143,7 +142,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 070d47efdcff..eab15068d252 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1259,7 +1259,6 @@ def __init__(self, config: Phi4MultimodalConfig, layer_idx: Optional[int] = None op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -1284,7 +1283,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 28316fc97d33..060683cfb972 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -222,7 +222,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - 
query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index efb99a6381b2..b06d3182b273 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -58,7 +58,6 @@ def __init__(self, config: Qwen2Config, layer_idx: int): self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -77,7 +76,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index f4503eaa527f..0826873a8f98 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -908,7 +908,6 @@ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None: self.config = config self.attention_dropout = 0.0 self.is_causal = False - self.rotary_fn = apply_rotary_pos_emb_vision def forward( self, @@ -921,8 +920,8 @@ def forward( query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) - query_states = self.rotary_fn(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) - key_states = self.rotary_fn(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) @@ -3034,7 +3033,6 @@ def __init__(self, config: Qwen2_5OmniDiTConfig): self.to_v = nn.Linear(config.hidden_size, self.inner_dim) self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)]) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -3059,7 +3057,7 @@ def forward( # apply rotary position embedding # Due to training process, only first head is applied with RoPE, will be fixed at next release cos, sin = position_embeddings - query[:, :1], key[:, :1] = self.rotary_fn(query[:, :1], key[:, :1], cos, sin) + query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_weights, _ = attention_interface( diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py 
b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index b0f0b96f62da..685c6b5c86f0 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1840,7 +1840,6 @@ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None: self.config = config self.attention_dropout = 0.0 self.is_causal = False - self.rotary_fn = apply_rotary_pos_emb_vision def forward( self, @@ -1853,8 +1852,8 @@ def forward( query_states = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) key_states = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) value_states = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) - query_states = self.rotary_fn(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) - key_states = self.rotary_fn(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) + key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) @@ -3208,7 +3207,6 @@ def __init__(self, config: Qwen2_5OmniDiTConfig): self.to_v = nn.Linear(config.hidden_size, self.inner_dim) self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)]) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -3233,7 +3231,7 @@ def forward( # apply rotary position embedding # Due to training process, only first head is applied with RoPE, will be fixed at next release cos, sin = position_embeddings - query[:, :1], key[:, :1] = self.rotary_fn(query[:, :1], key[:, :1], cos, sin) + query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_weights, _ = attention_interface( diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 65d5d04090aa..9b6c41662929 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -269,7 +269,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index c4b9ee5a7ed0..2f113785a94e 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -36,6 +36,7 @@ Qwen2ForTokenClassification, Qwen2RMSNorm, Qwen2RotaryEmbedding, + apply_rotary_pos_emb, eager_attention_forward, ) from .configuration_qwen3 import Qwen3Config @@ -83,7 +84,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git 
a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 0d09deb61ef2..caf7e26a39fe 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -168,7 +168,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index f757ca3166c2..27581014e9f7 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -399,7 +399,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 40b30e929193..7deedb9c868b 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -227,7 +227,6 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias ) del self.sliding_window - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -251,7 +250,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 7883e6afd4b9..73709df5d0bc 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1494,7 +1494,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -2374,7 +2374,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; 
cache_position needed for the static cache @@ -3402,7 +3402,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 7b697f5639f7..9922ccb09bb6 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -462,7 +462,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 4bc849dfb0f0..60253ce21551 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -405,7 +405,6 @@ class Qwen3VLTextAttention(Qwen3Attention): def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): super().__init__(config, layer_idx) del self.sliding_window - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -424,7 +423,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 0b60dc4ba3f3..28e2c85f156c 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -277,7 +277,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/seed_oss/modeling_seed_oss.py b/src/transformers/models/seed_oss/modeling_seed_oss.py index 92597e7e91ca..877d5eaa4fc0 100644 --- a/src/transformers/models/seed_oss/modeling_seed_oss.py +++ b/src/transformers/models/seed_oss/modeling_seed_oss.py @@ -183,7 +183,6 @@ def __init__(self, config: SeedOssConfig, layer_idx: int): ) self.residual_dropout = config.residual_dropout - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -202,7 +201,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, 
sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/seed_oss/modular_seed_oss.py b/src/transformers/models/seed_oss/modular_seed_oss.py index 7f5a218e74dc..96cb4d428214 100644 --- a/src/transformers/models/seed_oss/modular_seed_oss.py +++ b/src/transformers/models/seed_oss/modular_seed_oss.py @@ -94,7 +94,6 @@ def __init__(self, config: SeedOssConfig, layer_idx: int): ) self.residual_dropout = config.residual_dropout - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -113,7 +112,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index 8819b361f846..c59906f63692 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -236,7 +236,7 @@ def forward( if self.use_rope: cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} diff --git a/src/transformers/models/smollm3/modular_smollm3.py b/src/transformers/models/smollm3/modular_smollm3.py index d0f6d08c6279..995c15f10de0 100644 --- a/src/transformers/models/smollm3/modular_smollm3.py +++ b/src/transformers/models/smollm3/modular_smollm3.py @@ -238,7 +238,6 @@ def __init__(self, config: SmolLM3Config, layer_idx: int): if config.use_sliding_window and config.layer_types[layer_idx] == "sliding_attention" else None ) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -258,7 +257,7 @@ def forward( if self.use_rope: cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 3510365ed95e..c013e97f2169 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -177,7 +177,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index e6e048e18016..2c0a27f81bdf 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ 
b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -76,7 +76,6 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -95,7 +94,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index 437f172936df..5c0410655ee4 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -286,7 +286,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/vaultgemma/modeling_vaultgemma.py b/src/transformers/models/vaultgemma/modeling_vaultgemma.py index 9a7c174a34a8..ad93f819fa0e 100644 --- a/src/transformers/models/vaultgemma/modeling_vaultgemma.py +++ b/src/transformers/models/vaultgemma/modeling_vaultgemma.py @@ -210,7 +210,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 61e835f11b3f..8442a4a7c175 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -200,7 +200,6 @@ def __init__(self, config): self.num_key_value_groups = 1 self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.rotary_fn = apply_rotary_pos_emb_vision def forward( self, @@ -224,7 +223,7 @@ def forward( value_states = self.v_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index fbc4d9259939..0bd05ec9acf8 
100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -278,7 +278,6 @@ def __init__(self, config): self.attention_dropout = config.attention_dropout del self.scale del self.dropout - self.rotary_fn = apply_rotary_pos_emb_vision def forward( self, @@ -302,7 +301,7 @@ def forward( value_states = self.v_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index ef31502c40ec..9ade9d1dfc8d 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -416,7 +416,6 @@ def __init__( self.linear_v_adapter_list.append(linear_v_adapter) self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -446,7 +445,7 @@ def forward( if self.config.use_mem_rope: cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: key_states, value_states = past_key_values.update(key_states, value_states, layer_idx) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index c744bb1e3954..b4761a50dd42 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -224,7 +224,6 @@ def __init__( self.linear_v_adapter_list.append(linear_v_adapter) self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} - self.rotary_fn = apply_rotary_pos_emb def forward( self, @@ -254,7 +253,7 @@ def forward( if self.config.use_mem_rope: cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: key_states, value_states = past_key_values.update(key_states, value_states, layer_idx) From 4b9da306ad4742fb4fa7239b30653b84f588e3ca Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 28 Nov 2025 16:07:44 +0000 Subject: [PATCH 32/32] update code Signed-off-by: Liu, Kaixuan --- .../models/arcee/modeling_arcee.py | 2 +- src/transformers/models/aria/modeling_aria.py | 2 +- .../models/bamba/modeling_bamba.py | 2 +- src/transformers/models/csm/modeling_csm.py | 2 +- src/transformers/models/dia/modeling_dia.py | 39 ++++++++++++++++++- src/transformers/models/emu3/modeling_emu3.py | 2 +- .../models/ernie4_5/modeling_ernie4_5.py | 2 +- .../ernie4_5_moe/modeling_ernie4_5_moe.py | 2 +- .../models/evolla/modeling_evolla.py | 2 +- .../models/gemma/modeling_gemma.py | 2 +- src/transformers/models/glm/modeling_glm.py | 2 +- src/transformers/models/glm4/modeling_glm4.py | 2 +- .../models/granite/modeling_granite.py | 2 +- .../models/granitemoe/modeling_granitemoe.py | 2 +- .../modeling_granitemoeshared.py | 2 +- .../models/helium/modeling_helium.py | 2 +- .../models/llama/modeling_llama.py | 2 +- 
.../models/phimoe/modeling_phimoe.py | 2 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 2 +- 19 files changed, 55 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index 15a10957c7f3..d7faffd25c6e 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -264,7 +264,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index c62becaf93fb..998f770a4562 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -488,7 +488,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 7a5ef3370088..7f6cceeb8372 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -389,7 +389,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 80a4b3d1c6fe..3c58c7f6d45c 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -316,7 +316,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index a29eda2ff74f..55bbfd96411d 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -27,7 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import 
GradientCheckpointingLayer @@ -193,6 +193,41 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, @@ -269,7 +304,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index a0eafe4db0dc..d3d272aa8fdd 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -162,7 +162,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py index 4f233c456b14..c18f19206b21 100644 --- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -240,7 +240,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index 63154a034350..75c36992b22d 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -263,7 +263,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index 7a05cb0a4062..2260204eb150 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -1135,7 +1135,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index d3280b45834d..ae5e25202a1d 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -263,7 +263,7 
@@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 07ca34658c5e..aefb879dc659 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -258,7 +258,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index d6d7fe16cd59..3bfec8706d8f 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -240,7 +240,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 6e7969548aa7..ae2d017d53bf 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -160,7 +160,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 44eb6d705b2f..d50147d5c686 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -382,7 +382,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index df9af83e23a9..78bcdd03c43e 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -372,7 +372,7 @@ def 
forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index c075c5ed06a9..97a0f89b212b 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -262,7 +262,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5aa004cee8c7..eb938b6cbf7e 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -268,7 +268,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 0b71816b0c86..fc7c733f3271 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -238,7 +238,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 8ea900211aae..389bf016243a 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -265,7 +265,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = self.rotary_fn(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache
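
For reference, below is a minimal standalone sketch of the rotate_half / apply_rotary_pos_emb path that the hunks above standardize on (direct apply_rotary_pos_emb calls in place of self.rotary_fn). The toy shapes and inv_freq table are illustrative assumptions rather than values from any model config, and the use_kernel_func_from_hub decorator added in the dia hunk is omitted here so the snippet runs with plain PyTorch.

    import torch


    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)


    def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        # cos/sin arrive as [batch, seq_len, head_dim]; unsqueeze so they broadcast
        # against q/k of shape [batch, num_heads, seq_len, head_dim].
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed


    # Toy query/key tensors (shapes are assumptions for illustration only).
    batch, num_heads, seq_len, head_dim = 1, 2, 4, 8
    q = torch.randn(batch, num_heads, seq_len, head_dim)
    k = torch.randn(batch, num_heads, seq_len, head_dim)

    # Toy rotary table standing in for a model's rotary embedding module.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float()
    freqs = torch.outer(positions, inv_freq)                 # [seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)                  # [seq_len, head_dim]
    cos, sin = emb.cos()[None, ...], emb.sin()[None, ...]    # [1, seq_len, head_dim]

    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 4, 8]) for both

Calling the module-level function directly, as in the replaced lines above, keeps the attention forward passes free of the per-instance self.rotary_fn indirection while leaving the kernel dispatch (when enabled) to the decorator on the function definition.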