add rotary kernel support to Qwen3 model #41147

Merged
MekkCyber merged 43 commits into huggingface:main from kaixuanliu:rotary-kernel
Nov 28, 2025

Conversation

@kaixuanliu (Contributor) commented Sep 25, 2025

This PR adds the rotary kernels from https://huggingface.co/kernels-community/rotary to the Qwen3 series models.

Here are some benchmarks comparing the performance of the rotary kernels against the apply_rotary_pos_emb function in transformers.
For A100:
[figure: rotary_a100_combined_comparison]
And for Intel XPU:
[figure: rotary_xpu_combined_comparison]
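
For reference, a minimal sketch of pulling this kernel from the Hub (assuming the `kernels` package is installed; the `apply_rotary` call signature below is the one used by the wrapper discussed later in this thread):

```python
from kernels import get_kernel

# Download the kernel from the Hub (cached locally after the first call);
# requires the `kernels` package and a supported accelerator.
rotary = get_kernel("kernels-community/rotary")

# Low-level entry point that this PR wraps to match transformers'
# apply_rotary_pos_emb; it rotates the two input halves in place:
#   rotary.apply_rotary(q1, q2, cos, sin, out1, out2, conjugate)
print(rotary.apply_rotary)
```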

@kaixuanliu kaixuanliu marked this pull request as ready for review September 25, 2025 07:45
@kaixuanliu (Contributor, Author) commented Sep 25, 2025

I benchmarked the Qwen/Qwen3-4B-Instruct-2507 model: on Intel XPU it gets a ~10% end-to-end (E2E) performance improvement, while on A100 there is no obvious improvement or regression. Please let me know if it is OK to apply the rotary kernel in this manner, and I will then add support for more models.

@kaixuanliu kaixuanliu marked this pull request as draft September 25, 2025 08:43
@kaixuanliu kaixuanliu marked this pull request as ready for review September 25, 2025 10:26
@Rocketknight1 (Member)

cc @ArthurZucker

@MekkCyber (Contributor) left a comment

Thanks for this integration @kaixuanliu! I left a few nits to consider.

Comment on lines +517 to +519
global use_kernels
use_kernels = getattr(self, "use_kernels", False)

Contributor:

It's better to pass an attention kwarg, for example use_rotary_kernel, than to define a global variable like this.

Contributor Author:

You mean add a param called use_rotary_kernel to kwargs here and pass it down to Qwen3Attention?
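
For illustration, a sketch of that kwarg-threading idea (the `use_rotary_kernel` name and the dispatch helper are hypothetical, not the final API; `kernel_fn` stands in for this PR's apply_rotary_kernel wrapper):

```python
from transformers.models.qwen3.modeling_qwen3 import apply_rotary_pos_emb

def rotate_qk(q, k, cos, sin, kernel_fn=None, **kwargs):
    """Dispatch on a per-call attention kwarg instead of a module-level global.

    `use_rotary_kernel` would be threaded from the model's forward down into
    Qwen3Attention via **kwargs, so nothing is decided at import time.
    """
    if kwargs.get("use_rotary_kernel", False) and kernel_fn is not None:
        return kernel_fn(q, k, cos, sin)
    return apply_rotary_pos_emb(q, k, cos, sin)
```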

from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations.hub_kernels import rotary_kernel
Contributor:

I think we need to lazily load the kernel, because here we are loading it before even knowing if the user wants to use kernels or not
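
e.g. something along these lines (a sketch of the lazy-loading pattern; per the commit log below, the helper that eventually landed is called `lazy_load_kernel` in hub_kernels.py):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def _get_rotary_kernel():
    """Fetch the hub kernel the first time it is actually needed, instead of
    at import time; returns None when `kernels` is unavailable so callers
    can fall back to the pure-PyTorch path."""
    try:
        from kernels import get_kernel
    except ImportError:
        return None
    return get_kernel("kernels-community/rotary")
```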

@kaixuanliu (Contributor, Author) commented Sep 26, 2025

Thanks for your advice! I have updated the related code.

Comment on lines +125 to +148

def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Rotary kernel implementation wrapper.
    Adapts the rotary kernel implementation to match the HuggingFace
    apply_rotary_pos_emb signature.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_rotated = q.clone()
    k_rotated = k.clone()

    # Get half dimension for rotation
    half_dim = q.shape[-1] // 2
    q1 = q_rotated[..., :half_dim]
    q2 = q_rotated[..., half_dim:]
    k1 = k_rotated[..., :half_dim]
    k2 = k_rotated[..., half_dim:]
    if cos.shape[-1] != half_dim:
        # Trim cos/sin to match half_dim
        cos = cos[..., :half_dim]
        sin = sin[..., :half_dim]

    # Apply rotary embedding using our kernel (in place on the q halves)
    rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False)
    # (diff excerpt truncated here; lines +125 to +148 continue with the
    # matching call for the k halves and the return of the rotated tensors)
Contributor:

Did you try to benchmark the performance with and without this kernel?

Contributor Author:

Yes. On Intel XPU a single rotary op takes 0.22 ms, dropping to 0.1 ms after applying this patch: above a 2x speedup.
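
A sketch of how such a microbenchmark can be reproduced (shapes, dtype, and iteration counts are illustrative; this times the eager baseline, and the kernel path can be swapped into the same harness):

```python
import time
import torch
from transformers.models.qwen3.modeling_qwen3 import apply_rotary_pos_emb

device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"
sync = torch.xpu.synchronize if device == "xpu" else torch.cuda.synchronize

# Qwen3-4B-like shapes: 32 query heads, 8 KV heads, head_dim 128, seq 1024.
q = torch.randn(1, 32, 1024, 128, device=device, dtype=torch.bfloat16)
k = torch.randn(1, 8, 1024, 128, device=device, dtype=torch.bfloat16)
cos = torch.randn(1, 1024, 128, device=device, dtype=torch.bfloat16)
sin = torch.randn(1, 1024, 128, device=device, dtype=torch.bfloat16)

def bench(fn, iters=100, warmup=10):
    for _ in range(warmup):
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    return (time.perf_counter() - start) / iters * 1e3  # ms per call

print(f"eager apply_rotary_pos_emb: {bench(lambda: apply_rotary_pos_emb(q, k, cos, sin)):.3f} ms")
```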

@ArthurZucker (Collaborator) left a comment

Hey! Unfortunately this is not how we want to be adding support for kernels in general!
There should be 0 modeling changes involved; especially here, it does not even seem to be required!

We'd rather import once from kernels to replace the rotary embed if the function is defined, or something like that. But in the broad scheme of things, we want a mapping for functions, like we do for classes!

@kaixuanliu (Contributor, Author) commented Oct 10, 2025

@ArthurZucker Thanks for the comment, it makes sense. Since we do not have a schema for mapping functions the way kernels maps classes, we will add that support first, and based on it I will adjust this PR.

@yao-matrix (Contributor)

@ArthurZucker @danieldk, could you comment on the feasibility of Kaixuan's proposal?

@ArthurZucker (Collaborator)

I think @MekkCyber is working on that feature specifically!

@kaixuanliu (Contributor, Author)

> I think @MekkCyber is working on that feature specifically!

@ArthurZucker @MekkCyber, do you mean this PR: #41577?

@MekkCyber (Contributor) commented Oct 17, 2025

Hey @kaixuanliu, yes, we will start using the hub mapping in the PR you linked, but the kernel needs to be a drop-in replacement for the function in the modeling, so we don't have to change the modeling files apart from lazily loading the kernel. In case you need a special function, as in the case of rotary, we can expose it directly in the kernel.
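
In other words, the hub function would mirror the transformers signature exactly. As a sketch of that contract (the body shown is the reference PyTorch math that a compiled kernel would replace):

```python
import torch

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """A drop-in hub kernel must keep exactly this name, argument list, and
    (q_embed, k_embed) return contract so no modeling code has to change."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```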

@kaixuanliu kaixuanliu marked this pull request as draft October 21, 2025 06:36
@ArthurZucker (Collaborator) left a comment

Very very nice! IDK if we "have" to use the self.rotary_fn func? If not it would be perfect hehe

Kudos everyone 🚀

        self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
        self.rotary_fn = apply_rotary_pos_emb
Collaborator:

Do we need self? If not, we can just directly use the func? (I did not follow precisely!)

Contributor Author:

This is generated by modular; if we modify it, utils/check_modular_conversion.py will fail.

Collaborator:

I know, I mean for the original model!

Contributor:

self here is necessary: the kernelized function must be included in the module so that calling kernelize on the model can detect it.
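
A minimal sketch of that constraint (class name is illustrative): kernelize walks the model's module tree, so the function has to be reachable as a module attribute to be swapped.

```python
import torch.nn as nn
from transformers.models.qwen3.modeling_qwen3 import apply_rotary_pos_emb

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        # Bound on the module so kernelize(model), which traverses
        # submodules and their attributes, can find and replace it; a bare
        # module-level call site would be invisible to that traversal.
        self.rotary_fn = apply_rotary_pos_emb

    def forward(self, q, k, cos, sin):
        return self.rotary_fn(q, k, cos, sin)
```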

    return lambda cls: cls


def use_kernel_func_from_hub(func_name: str):
    if _kernels_enabled and _has_use_kernel_func_from_hub:
Collaborator:

@MekkCyber we need some docs here on usage etc!
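
Until those docs land, a rough usage sketch (assuming the `kernels` package is installed; `kernelize` and `Mode` are that library's public API, and kernel swapping stays opt-in):

```python
import torch
from kernels import Mode, kernelize
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B-Instruct-2507", dtype=torch.bfloat16, device_map="auto"
)
# Functions decorated with @use_kernel_func_from_hub keep their pure-Python
# behavior by default; the hub kernels are only swapped in when the model
# is explicitly kernelized.
model = kernelize(model, mode=Mode.INFERENCE)
```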

    return attn_output, attn_weights


@use_kernel_func_from_hub("rotary_pos_emb")
Collaborator:

Let's put it on llama and all models that have the same, no?

Contributor Author:

Thanks for the advice! Since the current implementation does not use kernels for functions by default, unlike the former version, I think it is OK to add this to all models. I have updated the code.

@kaixuanliu kaixuanliu marked this pull request as draft November 28, 2025 14:43
@github-actions (bot)

[For maintainers] Suggested jobs to run (before merge)

run-slow: apertus, arcee, aria, bamba, bitnet, cohere, csm, cwm, dbrx, deepseek_v3, dia, diffllama, doge, dots1, emu3, ernie4_5

@ArthurZucker ArthurZucker marked this pull request as ready for review November 28, 2025 17:02
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@MekkCyber (Contributor) left a comment

Thanks for your patience @kaixuanliu! LGTM

@MekkCyber MekkCyber merged commit 6587d77 into huggingface:main Nov 28, 2025
24 checks passed
@vasqu vasqu mentioned this pull request Nov 28, 2025
sarathc-cerebras pushed a commit to sarathc-cerebras/transformers that referenced this pull request Dec 7, 2025
* add rotary kernel support to Qwen3 model
* delete unnecessary import
* adjust code
* adjust code
* put get rotary kernel to hub_kernels.py
* fix wrong import
* refine code and adjust related modular code
* fix modular mismatch bug
* update code, use lazy load kernels
* fix check modular conversion issue
* fix CI bug for qwen3-next
* fix CI issue
* delete unused code
* rename to `apply_rotary_transformers`
* adjust import `lazy_load_kernel` location
* Update modular-generated modeling files with lazy_load_kernel import location
* fix conflicts
* add more check
* use decorator to map kernels for functions
* small fix
* small adjustment
* update code
* fix LINT issue
* update code to adapt to new `use_kernel_func_from_hub` API in kernels
* do not consider check_modular first
* update
* fix
* add compatibility for old version `kernels`
* add rotary fn kernel to all models
* update modular part
* Revert "update modular part" (this reverts commit b8b68c7.)
* update code

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
@Cyrilvallez (Member)

Hmm, this adds a random self.rotary_fn to the module which is not used... IMO forward should be changed to use self.rotary_fn then!

@kaixuanliu (Contributor, Author)

Hi @Cyrilvallez, self.rotary_fn is needed here, as the current design of the use_kernel_func_from_hub decorator needs to bind the function to a Module. Maybe we can add a line of comment here?

@MekkCyber (Contributor)

It's not really necessary to use self.rotary_fn, since it's only there to make the function discoverable by the kernelize process. Btw, I'm trying to think of a better way to do that.

@Cyrilvallez (Member)

Yeah, I know it's needed; I was just saying that it's a bit awkward right now as it's not being used! But all good, if @MekkCyber is looking for a better way then it can wait in the meantime!

@kaixuanliu (Contributor, Author)

Yes, I agree. It would be much better to use self.rotary_fn in place of apply_rotary_pos_emb in forward. But let's wait and see if @MekkCyber comes up with a better design.

SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
ArthurZucker added a commit to ArthurZucker/transformers that referenced this pull request Apr 13, 2026
ArthurZucker added a commit that referenced this pull request Apr 13, 2026
…45414)

* Fix `IndexError: pop from an empty deque` under DeepSpeed ZeRO-3

When `kernels` is installed, `@use_kernelized_func` attaches a
`rotary_fn` child `nn.Module` to attention layers. DeepSpeed ZeRO-3's
parameter coordinator traces the module graph at init and expects
every registered submodule to be invoked during forward. The model's
forward still calls the plain Python `apply_rotary_pos_emb`, so
`rotary_fn` is never executed and the trace desynchronizes, raising
`IndexError: pop from an empty deque` on the second forward.

Skip attaching the kernelized submodule when ZeRO-3 is enabled; users
running under ZeRO-3 fall back to the Python implementation, which is
what they were getting before #41147.

Fixes #45137

* Add dates to new model cards to satisfy check-repository-consistency
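
A sketch of the guard described in that commit message (`is_deepspeed_zero3_enabled` is the existing transformers helper; the attach helper around it is illustrative):

```python
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

def maybe_attach_kernel_fn(module, name, kernel_module):
    """Skip registering the kernelized rotary fn under ZeRO-3: a child
    nn.Module that is registered but never called during forward would
    desynchronize DeepSpeed's parameter-coordinator trace."""
    if is_deepspeed_zero3_enabled():
        return  # fall back to the plain Python apply_rotary_pos_emb
    setattr(module, name, kernel_module)
```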
sirzechs66 pushed a commit to sirzechs66/transformers that referenced this pull request Apr 18, 2026