diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 15dd7518150c..05f410c4437f 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -72,6 +72,7 @@ "register_kernel_mapping", "replace_kernel_forward_from_hub", "use_kernel_forward_from_hub", + "use_kernel_func_from_hub", ], "integration_utils": [ "INTEGRATION_TO_CALLBACK", @@ -212,6 +213,7 @@ register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, + use_kernel_func_from_hub, ) from .integration_utils import ( INTEGRATION_TO_CALLBACK, diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 0ab866ecbd6d..bada38965aed 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -34,6 +34,23 @@ register_kernel_mapping, replace_kernel_forward_from_hub, ) + from kernels import ( + use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub, + ) + + # Try to import FuncRepository, fallback if not available + try: + from kernels import FuncRepository + except ImportError: + FuncRepository = None + + # Try to import use_kernel_func_from_hub, fallback if not available + try: + from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub + + _has_use_kernel_func_from_hub = True + except ImportError: + _has_use_kernel_func_from_hub = False _TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper() _kernels_available = True @@ -41,8 +58,6 @@ def use_kernel_forward_from_hub(layer_name: str): if _kernels_enabled: - from kernels import use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub - return _kernels_use_kernel_forward_from_hub(layer_name) else: logger.warning_once( @@ -50,6 +65,21 @@ def use_kernel_forward_from_hub(layer_name: str): ) return lambda cls: cls + def use_kernel_func_from_hub(func_name: str): + if _kernels_enabled and _has_use_kernel_func_from_hub: + return _kernels_use_kernel_func_from_hub(func_name) + else: + if not _has_use_kernel_func_from_hub: + logger.warning_once( + "use_kernel_func_from_hub is not available in the installed kernels version. " + "Please upgrade kernels to use this feature." + ) + else: + logger.warning_once( + f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}" + ) + return lambda func: func + _KERNEL_MAPPING: dict[str, dict[Device | str, LayerRepository]] = { "MultiScaleDeformableAttention": { "cuda": LayerRepository( @@ -164,6 +194,16 @@ def use_kernel_forward_from_hub(layer_name: str): }, } + # Add function kernel mappings if FuncRepository is available + if FuncRepository is not None: + _KERNEL_MAPPING["rotary_pos_emb"] = { + "xpu": { + Mode.INFERENCE: FuncRepository( + repo_id="kernels-community/rotary", func_name="apply_rotary_transformers" + ) + } + } + def has_key(d, key): return key in d or any(isinstance(v, dict) and has_key(v, key) for v in d.values()) @@ -189,6 +229,12 @@ def decorator(cls): return decorator + def use_kernel_func_from_hub(*args, **kwargs): + def decorator(func): + return func + + return decorator + class LayerRepository: def __init__(self, *args, **kwargs): raise RuntimeError("LayerRepository requires `kernels` to be installed. 
Run `pip install kernels`.") @@ -201,6 +247,11 @@ def replace_kernel_forward_from_hub(*args, **kwargs): def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") + def register_kernel_mapping_transformers(*args, **kwargs): + raise RuntimeError( + "register_kernel_mapping_transformers requires `kernels` to be installed. Run `pip install kernels`." + ) + _HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = { "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"}, @@ -319,6 +370,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ __all__ = [ "LayerRepository", "use_kernel_forward_from_hub", + "use_kernel_func_from_hub", "register_kernel_mapping", "register_kernel_mapping_transformers", "replace_kernel_forward_from_hub", diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index 77a3d65478d6..1780b6383cc4 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForTokenClassification, GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -147,6 +147,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -237,6 +238,7 @@ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps) diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index 779f4a63e378..d7faffd25c6e 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import ( GenericForQuestionAnswering, @@ -154,6 +154,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
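The hub_kernels.py changes above and the per-model edits throughout this diff apply one pattern: the reference `apply_rotary_pos_emb` is decorated with `use_kernel_func_from_hub("rotary_pos_emb")`, and the attention module keeps a `self.rotary_fn` handle to whichever implementation the decorator resolved. A minimal sketch of that pattern, outside the patch (the `ToyAttention` module is hypothetical; the decorator is the one exported from `transformers.integrations` above):

```python
import torch

from transformers.integrations import use_kernel_func_from_hub


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # Reference implementation. With a recent `kernels` install the decorator may
    # swap in a hub kernel (e.g. kernels-community/rotary on XPU); otherwise it is a no-op.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class ToyAttention(torch.nn.Module):
    # Hypothetical module mirroring the `self.rotary_fn = apply_rotary_pos_emb`
    # lines added to each attention class in this diff.
    def __init__(self):
        super().__init__()
        self.rotary_fn = apply_rotary_pos_emb

    def forward(self, q, k, cos, sin):
        return self.rotary_fn(q, k, cos, sin)
```

Keeping the function as a module attribute mirrors the diff and gives call sites a single dispatch point that picks up whichever rotary implementation the decorator resolved at import time.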
@@ -244,6 +245,7 @@ def __init__(self, config: ArceeConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 96a6a82da91d..998f770a4562 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -29,7 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -378,6 +378,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -468,6 +469,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 2428222e0dbe..7f6cceeb8372 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -370,6 +370,7 @@ def __init__(self, config: BambaConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index 73597dd98d82..25568ca92365 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -27,7 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -85,6 +85,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -175,6 +176,7 @@ def __init__(self, config: BitNetConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 9fc2593d3175..752b693fe7b8 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -247,6 +247,7 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.use_qk_norm = config.use_qk_norm if self.use_qk_norm: # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 87da76281717..3c58c7f6d45c 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -32,7 +32,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -206,6 +206,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -296,6 +297,7 @@ def __init__(self, config: CsmConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py index df9760ed1ba7..53cc8c23a8c6 100644 --- a/src/transformers/models/cwm/modeling_cwm.py +++ b/src/transformers/models/cwm/modeling_cwm.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -113,6 +113,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -195,6 +196,7 @@ def __init__(self, config: CwmConfig, layer_idx: int): self.k_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None def forward( diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index ddf5fce4dfce..92df09947bc4 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -112,6 +113,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index cfd8d91dfb9a..4c56277c69dd 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -16,7 +16,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -253,6 +253,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 3a0ddf6e3f90..55bbfd96411d 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -27,7 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -200,6 +200,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
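The `_KERNEL_MAPPING["rotary_pos_emb"]` entry added above only targets XPU inference. A hedged sketch of how a second device could be wired up, assuming a `kernels` version that ships `FuncRepository` and `Mode`, and assuming, purely for illustration, that kernels-community/rotary also provided a CUDA build (`_KERNEL_MAPPING` is an internal dict, not public API):

```python
# Hypothetical extension of the mapping added in hub_kernels.py; a CUDA variant of
# apply_rotary_transformers is an assumption made only to show the shape of an entry.
from kernels import FuncRepository, Mode

from transformers.integrations.hub_kernels import _KERNEL_MAPPING

_KERNEL_MAPPING["rotary_pos_emb"]["cuda"] = {
    Mode.INFERENCE: FuncRepository(
        repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
    )
}
```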
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 99524915b9f6..b23ee2ed948d 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -32,7 +32,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask from ...modeling_layers import ( @@ -141,6 +141,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index b9ebf9856264..a3a94f88df55 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -33,7 +33,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...integrations.flex_attention import compile_friendly_flex_attention from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer @@ -143,6 +143,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index b8ae00b6ab60..9092a3533e43 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -29,7 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -135,6 +135,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -226,6 +227,7 @@ def __init__(self, config: Dots1Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 65671913b27f..d3d272aa8fdd 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -33,7 +33,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -52,6 +52,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -142,6 +143,7 @@ def __init__(self, config: Emu3Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py index b53ddf923e70..c18f19206b21 100644 --- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -221,6 +221,7 @@ def __init__(self, config: Ernie4_5Config, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index ccd05fe26347..75c36992b22d 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -244,6 +244,7 @@ def __init__(self, config: Ernie4_5_MoeConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index f4d0ce11255f..2260204eb150 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils 
import create_bidirectional_mask, create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -1051,6 +1051,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -1115,6 +1116,7 @@ def __init__(self, config: EvollaConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/exaone4/modeling_exaone4.py b/src/transformers/models/exaone4/modeling_exaone4.py index cb70c9cff142..2287f9f8e88d 100644 --- a/src/transformers/models/exaone4/modeling_exaone4.py +++ b/src/transformers/models/exaone4/modeling_exaone4.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import ( GenericForQuestionAnswering, @@ -140,6 +140,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index a6fd7a5aba99..1cf19c500737 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -36,7 +36,7 @@ from ... import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -295,6 +295,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -385,6 +386,7 @@ def __init__(self, config: FalconH1Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.key_multiplier = config.key_multiplier def forward( diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py index b948b420ad63..993a3dae1652 100644 --- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py +++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py @@ -241,6 +241,7 @@ def __init__(self, config: FlexOlmoConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = FlexOlmoRMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) self.k_norm = FlexOlmoRMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 8834ba8c3564..ae5e25202a1d 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import ( GenericForSequenceClassification, @@ -152,6 +153,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -242,6 +244,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 69e486032107..d6362634b1e3 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -156,6 +157,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -256,6 +258,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 8e93ef9231b5..253931d4eeb1 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -31,6 +31,7 @@ from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer @@ -230,6 +231,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -330,6 +332,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None self.is_sliding = self.layer_type == "sliding_attention" diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 3be6ad39ddca..3807311ff674 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1254,6 +1254,7 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None self.is_sliding = self.layer_type == "sliding_attention" @@ -1475,6 +1476,7 @@ def __init__(self, config: Gemma3nConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 8a508e2de54c..aefb879dc659 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -239,6 +239,7 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git 
a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index c982c36f9aab..3bfec8706d8f 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -221,6 +221,7 @@ def __init__(self, config: Glm4Config, layer_idx: Optional[int] = None): config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 9537d9018838..74016aa98e7e 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -204,6 +204,48 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Interleave them instead of usual shape + cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) + sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). 
@@ -280,6 +322,7 @@ def __init__(self, config: Glm4vMoeTextConfig, layer_idx: Optional[int] = None): config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb self.rope_parameters = config.rope_parameters def forward( diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 8e1ce9df0b97..5e1173d823d0 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -325,6 +325,7 @@ def __init__(self, config: GptOssConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 42de2e0724f3..ae2d017d53bf 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -50,6 +50,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -140,6 +141,7 @@ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index f722ad416a2f..d50147d5c686 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -272,6 +272,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -362,6 +363,7 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index c9e7245956f3..9534b45c0624 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -31,7 +31,7 @@ from ... import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -59,6 +59,41 @@ logger = logging.get_logger(__name__) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, @@ -122,6 +157,7 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 606a59390e6e..78bcdd03c43e 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -262,6 +262,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -352,6 +353,7 @@ def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index e616da3cd07b..97a0f89b212b 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -243,6 +243,7 @@ def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None): config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py index 4d184a0b1982..28899e306f4c 100644 --- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -87,6 +87,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -177,6 +178,7 @@ def __init__(self, config: HunYuanDenseV1Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.query_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.key_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index 281a50a9e2cc..dda4366f0d4d 100644 --- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -86,6 +86,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -176,6 +177,7 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.query_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.key_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index fbda366fc319..47673a3afab2 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -32,7 +32,7 @@ from ... import initialization as init from ...activations import ACT2FN from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -175,6 +175,41 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return self.key_cache[layer_idx].shape[-2] +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -229,6 +264,7 @@ def __init__(self, config: JambaConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index b102a111e10f..1ef99fc46d7f 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -366,6 +366,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index f1d639d16bbd..aa66a2e2d80c 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -26,7 +26,7 @@ from ...cache_utils import Cache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -292,6 +292,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -371,6 +372,7 @@ def __init__(self, config: Lfm2Config, layer_idx: int): self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.rotary_fn = apply_rotary_pos_emb self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index 73b9c4a8fde0..55dcde8b07e9 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MoeModelOutputWithPast @@ -363,6 +363,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -442,6 +443,7 @@ def __init__(self, config: Lfm2MoeConfig, layer_idx: int): self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.rotary_fn = apply_rotary_pos_emb self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.q_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps) self.k_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps) diff --git a/src/transformers/models/lightglue/modeling_lightglue.py b/src/transformers/models/lightglue/modeling_lightglue.py index 0d5044c5a40a..e684b11130d8 100644 --- a/src/transformers/models/lightglue/modeling_lightglue.py +++ b/src/transformers/models/lightglue/modeling_lightglue.py @@ -199,6 +199,7 @@ def __init__(self, config: LightGlueConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e3adac5d117d..eb938b6cbf7e 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -26,7 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import ( GenericForQuestionAnswering, @@ -142,6 +142,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -248,6 +249,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 004ed68cef23..8b8f8c9adb3a 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -326,6 +326,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -407,6 +408,7 @@ def __init__(self, config: MiniMaxConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/ministral/modeling_ministral.py b/src/transformers/models/ministral/modeling_ministral.py index b1c8555fd96b..ccf6c9abcda5 100644 --- a/src/transformers/models/ministral/modeling_ministral.py +++ b/src/transformers/models/ministral/modeling_ministral.py @@ -13,7 +13,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -54,6 +54,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -137,6 +138,7 @@ def __init__(self, config, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None def forward( diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 60c7e2d49eed..43dc5c18a90b 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -15,7 +15,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -55,6 +55,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -136,6 +137,7 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1faff1f4dcea..95b236dadce6 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -37,7 +37,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -225,6 +225,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -306,6 +307,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 8069f2bec2ff..4869da818e01 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -31,6 +31,7 @@ from ... import initialization as init from ...activations import ACT2FN +from ...integrations import use_kernel_func_from_hub from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -331,6 +332,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
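Reviewer sketch (not part of the patch): many of the attention constructors touched above gain `self.rotary_fn = apply_rotary_pos_emb`, and each model's module-level `apply_rotary_pos_emb` is decorated with `@use_kernel_func_from_hub("rotary_pos_emb")`. The forward passes that would consume the new attribute are not shown in these hunks, so the following is only a hedged illustration of the indirection, assuming a forward pass routes rotary application through `self.rotary_fn`; `ToyAttention`, the tensor shapes, and the call site are invented for the example.

import torch
from torch import nn


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # Eager reference implementation of the rotary transform used across these models.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


class ToyAttention(nn.Module):
    # Hypothetical module; only the `rotary_fn` attribute mirrors the hunks above.
    def __init__(self):
        super().__init__()
        self.rotary_fn = apply_rotary_pos_emb

    def forward(self, q, k, cos, sin):
        # Calling through the attribute keeps the call site stable if the
        # function is later swapped for a hub-provided kernel.
        return self.rotary_fn(q, k, cos, sin)


q = k = torch.randn(1, 2, 4, 8)  # [batch, num_heads, seq_len, head_dim]
cos = torch.randn(1, 4, 8)       # [batch, seq_len, head_dim]
sin = torch.randn(1, 4, 8)
q_rot, k_rot = ToyAttention()(q, k, cos, sin)  # both [1, 2, 4, 8]

Whether the models call `self.rotary_fn(...)` or keep calling the module-level function directly, the attribute only adds a level of indirection; the exact wiring is outside the hunks shown here.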
diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index 7564e375716b..5059e623c9dd 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -31,6 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast @@ -183,6 +184,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 373e1db4a217..87893284ec4e 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -264,6 +264,7 @@ def __init__( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb # Pad head dimension to the next specified multiple. if self.config.pad_head_dim_to_multiple_of is not None: diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 2ba7a25f71b5..1dcaa6247155 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -237,6 +237,7 @@ def __init__(self, config: OlmoConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 44e3592157af..aa6296f344f1 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -230,6 +230,7 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) diff --git a/src/transformers/models/olmo3/modeling_olmo3.py b/src/transformers/models/olmo3/modeling_olmo3.py index d49570982f48..003e691e9c82 100644 --- a/src/transformers/models/olmo3/modeling_olmo3.py +++ b/src/transformers/models/olmo3/modeling_olmo3.py @@ -161,6 +161,7 @@ def __init__(self, config: Olmo3Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Olmo3RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) self.k_norm = 
Olmo3RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) assert config.layer_types is not None diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index f078518e0c1f..d04cd421d441 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -27,7 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -148,6 +148,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -238,6 +239,7 @@ def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.k_norm = OlmoeRMSNorm( (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 079307547fe5..d11d56d3a11c 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -29,6 +29,7 @@ from ... import initialization as init from ...activations import ACT2FN +from ...integrations import use_kernel_func_from_hub from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, CausalLMOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -182,6 +183,41 @@ def forward(self, hidden_states, attention_mask=None): return hidden_states.transpose(1, 2) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -245,6 +281,7 @@ def __init__(self, config: ParakeetEncoderConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb # W_{k,R} projection self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) # global content bias diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 4a1530b78564..eec8794a3d65 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -13,6 +13,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import ( GenericForSequenceClassification, @@ -105,6 +106,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -185,6 +187,7 @@ def __init__(self, config: PhiConfig, layer_idx: int): self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.rotary_fn = apply_rotary_pos_emb self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) self.qk_layernorm = config.qk_layernorm diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 12e41214094d..fc7c733f3271 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -128,6 +128,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -218,6 +219,7 @@ def __init__(self, config: PhimoeConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb def forward( self, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 1215f3677603..060683cfb972 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -13,7 +13,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -119,6 +119,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -201,6 +202,7 @@ def __init__(self, config: Qwen2Config, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None def forward( diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 8bda140d3cdb..389bf016243a 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -35,7 +35,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import ( GenericForQuestionAnswering, @@ -161,6 +161,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -243,6 +244,7 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: int): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rotary_fn = apply_rotary_pos_emb if self.config.layer_types[layer_idx] == "sliding_attention": self.sliding_window = config.sliding_window diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 5f0f8974eb0a..9b6c41662929 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -155,6 +155,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -246,6 +247,7 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 477694d5fb2b..caf7e26a39fe 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -55,6 +55,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -145,6 +146,7 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = getattr(config, "sliding_window", None) diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 362c8fab007f..27581014e9f7 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -371,6 +371,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3NextRMSNorm( self.head_dim, eps=config.rms_norm_eps diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 1be0487cea98..73709df5d0bc 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -35,7 +35,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -1415,6 +1415,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -1467,6 +1468,7 @@ def __init__(self, config, layer_idx): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3OmniMoeThinkerTextRMSNorm( self.head_dim, eps=config.rms_norm_eps ) # unlike olmo, only on the head dim! @@ -2348,6 +2350,7 @@ def __init__(self, config: Qwen3OmniMoeConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3OmniMoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
self.k_norm = Qwen3OmniMoeRMSNorm( self.head_dim, eps=config.rms_norm_eps @@ -3377,6 +3380,7 @@ def __init__(self, config: Qwen3OmniMoeCode2WavConfig, layer_idx): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = nn.Identity() self.k_norm = nn.Identity() self.sliding_window = config.sliding_window diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index aab768b1cf1c..9922ccb09bb6 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -385,6 +385,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -438,6 +439,7 @@ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3VLTextRMSNorm( self.head_dim, eps=config.rms_norm_eps diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index eab677bce4fe..28e2c85f156c 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -198,6 +198,7 @@ def eager_attention_forward( return attn_output, attn_weights +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -251,6 +252,7 @@ def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = Qwen3VLMoeTextRMSNorm( self.head_dim, eps=config.rms_norm_eps ) # unlike olmo, only on the head dim! 
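Reviewer note (not part of the patch): the docstring added for Parakeet above describes `unsqueeze_dim` in prose. The short check below is a worked example of the broadcasting it refers to, using invented shapes rather than anything from this PR.

import torch

batch, heads, seq, head_dim = 1, 2, 4, 8
cos = torch.randn(batch, seq, head_dim)  # [batch_size, seq_len, head_dim], as in the docstring

# q/k laid out as [batch, heads, seq_len, head_dim] -> unsqueeze_dim=1 (the default)
q_bhsd = torch.randn(batch, heads, seq, head_dim)
assert (q_bhsd * cos.unsqueeze(1)).shape == (batch, heads, seq, head_dim)

# q/k laid out as [batch, seq_len, heads, head_dim] -> unsqueeze_dim=2
q_bshd = torch.randn(batch, seq, heads, head_dim)
assert (q_bshd * cos.unsqueeze(2)).shape == (batch, seq, heads, head_dim)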
diff --git a/src/transformers/models/seed_oss/modeling_seed_oss.py b/src/transformers/models/seed_oss/modeling_seed_oss.py index 682193ca8d51..877d5eaa4fc0 100644 --- a/src/transformers/models/seed_oss/modeling_seed_oss.py +++ b/src/transformers/models/seed_oss/modeling_seed_oss.py @@ -27,7 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import ( GenericForQuestionAnswering, @@ -90,6 +90,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index e23d4993e84c..c59906f63692 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -118,6 +118,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -208,6 +209,7 @@ def __init__(self, config: SmolLM3Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.use_rope = config.no_rope_layers[layer_idx] self.sliding_window = ( diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 042033fe3565..c013e97f2169 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -35,6 +35,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -74,6 +75,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -155,6 +157,7 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + self.rotary_fn = apply_rotary_pos_emb self.residual_dropout = config.residual_dropout def forward( diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index ac9b64929280..5c0410655ee4 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -162,6 +163,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -263,6 +265,7 @@ def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None @@ -338,6 +341,7 @@ def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_logit_softcapping = self.config.attn_logit_softcapping if config.cross_attention_hidden_size is None: diff --git a/src/transformers/models/vaultgemma/modeling_vaultgemma.py b/src/transformers/models/vaultgemma/modeling_vaultgemma.py index ee36d7519a53..ad93f819fa0e 100644 --- a/src/transformers/models/vaultgemma/modeling_vaultgemma.py +++ b/src/transformers/models/vaultgemma/modeling_vaultgemma.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -87,6 +88,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -187,6 +189,7 @@ def __init__(self, config: VaultGemmaConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 5b5f532cebf7..9ade9d1dfc8d 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -32,6 +32,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -316,6 +317,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors.