Merged
43 commits
69f2ca8
add rotary kernel support to Qwen3 model
kaixuanliu Sep 25, 2025
d2bf5c5
delete unnecessary import
kaixuanliu Sep 25, 2025
b0cbab5
adjust code
kaixuanliu Sep 25, 2025
8dede65
adjust code
kaixuanliu Sep 25, 2025
5c02189
Merge branch 'rotary-kernel' of https://github.com/kaixuanliu/transfo…
kaixuanliu Sep 25, 2025
137069b
put get rotary kernel to hub_kernels.py
kaixuanliu Sep 25, 2025
8ac3e1e
fix wrong import
kaixuanliu Sep 25, 2025
29f83f2
refine code and adjust related modular code
kaixuanliu Sep 26, 2025
7729b7f
Merge branch 'main' into rotary-kernel
kaixuanliu Sep 26, 2025
94e4f60
fix modular mismatch bug
kaixuanliu Sep 26, 2025
b96a7c9
Merge branch 'rotary-kernel' of https://github.com/kaixuanliu/transfo…
kaixuanliu Sep 26, 2025
aebac76
Merge branch 'main' into rotary-kernel
kaixuanliu Oct 20, 2025
af67a74
update code, use lazy load kernels
kaixuanliu Oct 21, 2025
28c69d3
fix check modular conversion issue
kaixuanliu Oct 21, 2025
cff1580
Merge branch 'main' into rotary-kernel
kaixuanliu Oct 22, 2025
7ac16a1
fix CI bug for qwen3-next
kaixuanliu Oct 22, 2025
adce121
fix CI issue
kaixuanliu Oct 22, 2025
b4757f4
delete unused code
kaixuanliu Oct 22, 2025
5ff6b16
Merge branch 'main' into rotary-kernel
kaixuanliu Oct 23, 2025
3bbfd64
rename to `apply_rotary_transformers`
kaixuanliu Oct 23, 2025
ec4ca1d
Merge branch 'main' into rotary-kernel
kaixuanliu Oct 29, 2025
caa549a
adjust import `lazy_load_kernel` location
kaixuanliu Oct 29, 2025
08a9959
Update modular-generated modeling files with lazy_load_kernel import …
kaixuanliu Oct 29, 2025
7915fc1
fix conflicts
kaixuanliu Oct 29, 2025
6f2d958
add more check
kaixuanliu Oct 31, 2025
e67ca60
Merge branch 'main' into rotary-kernel
kaixuanliu Nov 3, 2025
36c2205
Merge branch 'main' into rotary-kernel
kaixuanliu Nov 11, 2025
f4b12a7
use decorator to map kernels for functions
kaixuanliu Nov 19, 2025
6a3e6f3
small fix
kaixuanliu Nov 19, 2025
702dc09
small adjustment
kaixuanliu Nov 19, 2025
fdabd60
Merge branch 'main' into rotary-kernel
kaixuanliu Nov 19, 2025
fe2bf42
update code
kaixuanliu Nov 19, 2025
6f95969
fix LINT issue
kaixuanliu Nov 19, 2025
2ebea7b
Merge branch 'main' into rotary-kernel
kaixuanliu Nov 27, 2025
e80cfd3
update code to adapt to new `use_kernel_func_from_hub` API in kernels
kaixuanliu Nov 27, 2025
c771acd
do not consider check_modular first
kaixuanliu Nov 27, 2025
d916ef0
update
kaixuanliu Nov 27, 2025
8670efe
fix
kaixuanliu Nov 27, 2025
fe20bd5
add compatibility for old version `kernels`
kaixuanliu Nov 27, 2025
898e36e
add rotary fn kernel to all models
kaixuanliu Nov 28, 2025
b8b68c7
update modular part
kaixuanliu Nov 28, 2025
af25ce0
Revert "update modular part"
kaixuanliu Nov 28, 2025
4b9da30
update code
kaixuanliu Nov 28, 2025
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -72,6 +72,7 @@
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
"use_kernel_forward_from_hub",
"use_kernel_func_from_hub",
],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
@@ -212,6 +213,7 @@
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
use_kernel_func_from_hub,
)
from .integration_utils import (
INTEGRATION_TO_CALLBACK,
56 changes: 54 additions & 2 deletions src/transformers/integrations/hub_kernels.py
@@ -34,22 +34,52 @@
register_kernel_mapping,
replace_kernel_forward_from_hub,
)
from kernels import (
use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub,
)

# Try to import FuncRepository, fallback if not available
try:
from kernels import FuncRepository
except ImportError:
FuncRepository = None

# Try to import use_kernel_func_from_hub, fallback if not available
try:
from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub

_has_use_kernel_func_from_hub = True
except ImportError:
_has_use_kernel_func_from_hub = False

_TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper()
_kernels_available = True
_kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES

def use_kernel_forward_from_hub(layer_name: str):
if _kernels_enabled:
from kernels import use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub

return _kernels_use_kernel_forward_from_hub(layer_name)
else:
logger.warning_once(
f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
)
return lambda cls: cls

def use_kernel_func_from_hub(func_name: str):
if _kernels_enabled and _has_use_kernel_func_from_hub:
Collaborator: @MekkCyber we need some docs here on usage etc!

return _kernels_use_kernel_func_from_hub(func_name)
else:
if not _has_use_kernel_func_from_hub:
logger.warning_once(
"use_kernel_func_from_hub is not available in the installed kernels version. "
"Please upgrade kernels to use this feature."
)
else:
logger.warning_once(
f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
)
return lambda func: func

_KERNEL_MAPPING: dict[str, dict[Device | str, LayerRepository]] = {
"MultiScaleDeformableAttention": {
"cuda": LayerRepository(
@@ -164,6 +194,16 @@ def use_kernel_forward_from_hub(layer_name: str):
},
}

# Add function kernel mappings if FuncRepository is available
if FuncRepository is not None:
_KERNEL_MAPPING["rotary_pos_emb"] = {
"xpu": {
Mode.INFERENCE: FuncRepository(
repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
)
}
}

def has_key(d, key):
return key in d or any(isinstance(v, dict) and has_key(v, key) for v in d.values())

@@ -189,6 +229,12 @@ def decorator(cls):

return decorator

def use_kernel_func_from_hub(*args, **kwargs):
def decorator(func):
return func

return decorator

class LayerRepository:
def __init__(self, *args, **kwargs):
raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
@@ -201,6 +247,11 @@ def replace_kernel_forward_from_hub(*args, **kwargs):
def register_kernel_mapping(*args, **kwargs):
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")

def register_kernel_mapping_transformers(*args, **kwargs):
raise RuntimeError(
"register_kernel_mapping_transformers requires `kernels` to be installed. Run `pip install kernels`."
)


_HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = {
"causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"},
@@ -319,6 +370,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
__all__ = [
"LayerRepository",
"use_kernel_forward_from_hub",
"use_kernel_func_from_hub",
"register_kernel_mapping",
"register_kernel_mapping_transformers",
"replace_kernel_forward_from_hub",
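Since the review comment above asks for usage docs, here is a minimal, illustrative sketch of the new decorator (not the PR's own documentation). The decorated function is the standard eager rotary helper from the modeling files; the hub repo and function names mentioned in the comments are the ones registered above for the `rotary_pos_emb` mapping (XPU, inference mode).

import torch

from transformers.integrations import use_kernel_func_from_hub


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # Reference eager implementation. When `kernels` is installed, hub kernels
    # are enabled, and a mapping exists for the current device/mode (here
    # "xpu" + Mode.INFERENCE -> kernels-community/rotary, apply_rotary_transformers),
    # the decorator is expected to dispatch to that hub kernel instead of this body.
    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

With an older `kernels` release that lacks `use_kernel_func_from_hub`, or with `USE_HUB_KERNELS` set to a value outside ENV_VARS_TRUE_VALUES (e.g. NO), the decorator falls back to the identity wrapper shown in the diff, so the eager body keeps running unchanged.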
4 changes: 3 additions & 1 deletion src/transformers/models/apertus/modeling_apertus.py
@@ -28,7 +28,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GenericForTokenClassification, GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -147,6 +147,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -237,6 +238,7 @@ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.rotary_fn = apply_rotary_pos_emb
self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)

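The per-model changes that follow repeat this apertus pattern: the eager `apply_rotary_pos_emb` keeps its definition but gains the `@use_kernel_func_from_hub("rotary_pos_emb")` decorator, and each attention module stores it as `self.rotary_fn`. The forward bodies are collapsed in this view, so the toy module below is only a sketch of how the new attribute is presumably consumed; the class name, shapes, and call site are illustrative assumptions, not lines from the diff.

import torch
import torch.nn as nn

from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb


class ToyRotaryAttention(nn.Module):
    def __init__(self):
        super().__init__()
        # Same line as the diff: the (possibly kernel-backed) rotary helper is
        # stored on the instance instead of being called at module level.
        self.rotary_fn = apply_rotary_pos_emb

    def forward(self, q, k, position_embeddings):
        cos, sin = position_embeddings
        # Swapping in a hub kernel only requires replacing rotary_fn; the eager
        # implementation stays available as the fallback.
        return self.rotary_fn(q, k, cos, sin)


# Illustrative shapes: batch=1, heads=2, seq=4, head_dim=8.
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
cos, sin = torch.ones(1, 4, 8), torch.zeros(1, 4, 8)
q_rot, k_rot = ToyRotaryAttention()(q, k, (cos, sin))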
4 changes: 3 additions & 1 deletion src/transformers/models/arcee/modeling_arcee.py
@@ -30,7 +30,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import (
GenericForQuestionAnswering,
@@ -154,6 +154,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -244,6 +245,7 @@ def __init__(self, config: ArceeConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.rotary_fn = apply_rotary_pos_emb

def forward(
self,
4 changes: 3 additions & 1 deletion src/transformers/models/aria/modeling_aria.py
@@ -29,7 +29,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -378,6 +378,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -468,6 +469,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.rotary_fn = apply_rotary_pos_emb

def forward(
self,
1 change: 1 addition & 0 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -370,6 +370,7 @@ def __init__(self, config: BambaConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.rotary_fn = apply_rotary_pos_emb

def forward(
self,
4 changes: 3 additions & 1 deletion src/transformers/models/bitnet/modeling_bitnet.py
@@ -27,7 +27,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -85,6 +85,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -175,6 +176,7 @@ def __init__(self, config: BitNetConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.rotary_fn = apply_rotary_pos_emb
self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
1 change: 1 addition & 0 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -247,6 +247,7 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.rotary_fn = apply_rotary_pos_emb
self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm:
# When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
4 changes: 3 additions & 1 deletion src/transformers/models/csm/modeling_csm.py
@@ -32,7 +32,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -206,6 +206,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -296,6 +297,7 @@ def __init__(self, config: CsmConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.rotary_fn = apply_rotary_pos_emb

def forward(
self,
4 changes: 3 additions & 1 deletion src/transformers/models/cwm/modeling_cwm.py
@@ -28,7 +28,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -113,6 +113,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -195,6 +196,7 @@ def __init__(self, config: CwmConfig, layer_idx: int):
self.k_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
self.rotary_fn = apply_rotary_pos_emb
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

def forward(
2 changes: 2 additions & 0 deletions src/transformers/models/dbrx/modeling_dbrx.py
@@ -29,6 +29,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -112,6 +113,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

3 changes: 2 additions & 1 deletion src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -16,7 +16,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -253,6 +253,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

3 changes: 2 additions & 1 deletion src/transformers/models/dia/modeling_dia.py
@@ -27,7 +27,7 @@

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_bidirectional_mask, create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -200,6 +200,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

3 changes: 2 additions & 1 deletion src/transformers/models/diffllama/modeling_diffllama.py
@@ -32,7 +32,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
from ...modeling_layers import (
@@ -141,6 +141,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

3 changes: 2 additions & 1 deletion src/transformers/models/doge/modeling_doge.py
@@ -33,7 +33,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...integrations.flex_attention import compile_friendly_flex_attention
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
@@ -143,6 +143,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

4 changes: 3 additions & 1 deletion src/transformers/models/dots1/modeling_dots1.py
@@ -29,7 +29,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
@@ -135,6 +135,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

@@ -226,6 +227,7 @@ def __init__(self, config: Dots1Config, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.rotary_fn = apply_rotary_pos_emb
self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None