
🚨 [Kernels] Fix kernel function registration#45420

Merged

vasqu merged 5 commits into huggingface:main from vasqu:fix-kernels-loading-order
Apr 20, 2026

Conversation

@vasqu
Contributor

@vasqu vasqu commented Apr 13, 2026

Breaking change

🚨 Slightly breaking change: we no longer register the hidden `rotary_fn`. Users shouldn't have relied on it, but marking it just in case: calls like `self.rotary_fn(...)` within the attention module no longer work, as the reference is deleted from now on.
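A toy illustration of the breakage (the class and names here are hypothetical stand-ins, not the real transformers attention module):

```python
class Attention:
    """Toy stand-in for an attention module that previously carried a hidden kernel ref."""

    def __init__(self):
        # previously, the kernels integration left the exchanged function
        # reachable as an instance attribute
        self.rotary_fn = lambda q, k: (q, k)

attn = Attention()
del attn.rotary_fn  # after this PR, the hidden reference is deleted

try:
    attn.rotary_fn(1, 2)  # user code that relied on the hidden function
except AttributeError:
    print("rotary_fn is no longer reachable")
```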

Description

As per the title, we do not want proper `nn.Module`s to be registered for kernel-exchanged functions: they are not proper modules (and they are never called as such)! They act as an exchange format for kernels, but functionally they should remain pure functions only.

The exact reasons are numerous, but one recent example is DeepSpeed ZeRO-3, which cannot handle this: the module is never properly called in the forward of its parent module (making it untraceable), and the registration changes module structures after model construction (fixable by changing the order of inits, tbh).

This PR changes the core functionality to register the modules only temporarily under the parent module, discover the exchangeable functions, and then delete them from the visible interface. For BC purposes, we still keep the self reference that already exists (now as a simple attribute, not a module).
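A rough sketch of that register-discover-delete lifecycle, using native `nn.Module` APIs. Names like `_hidden_kernels` and `rotary_fn` are assumptions for illustration; the real helpers live in `modeling_utils.py`:

```python
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        # exchanged kernel kept in a plain dict, outside the module tree
        self._hidden_kernels = {"rotary_fn": nn.Identity()}

attn = Attention()

# 1) temporarily register so the kernels machinery can discover the function
for name, fn in attn._hidden_kernels.items():
    attn.register_module(name, fn)
assert "rotary_fn" in dict(attn.named_children())

# 2) once discovered, delete it from the visible module interface again
for name in attn._hidden_kernels:
    delattr(attn, name)
assert "rotary_fn" not in dict(attn.named_children())

# 3) BC: keep a plain attribute (not a submodule) pointing at the callable
attn.rotary_fn = attn._hidden_kernels["rotary_fn"].forward
```

Since the kept reference is a plain callable rather than a registered submodule, it stays invisible to `named_children()`, state dicts, and hooks, which is what DeepSpeed-style wrappers need.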

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@Cyrilvallez Cyrilvallez left a comment


Nice! Mostly nits to avoid relying on internals too much!

Comment thread src/transformers/integrations/hub_kernels.py Outdated
Comment thread src/transformers/modeling_utils.py Outdated
Comment thread src/transformers/modeling_utils.py Outdated
Comment thread src/transformers/modeling_utils.py Outdated
@vasqu vasqu marked this pull request as ready for review April 17, 2026 19:08
Comment on lines +228 to +230
Mode.TRAINING: FuncRepository(
repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
),
Contributor Author


Not sure why this wasn't included for training before, but it runs just fine with DeepSpeed. See also https://github.com/Dao-AILab/flash-attention/blob/b65ae6b175f2438de55601695b6a21971fc5e429/flash_attn/layers/rotary.py#L38-L90
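For illustration, the mode-to-kernel mapping in the diff amounts to a lookup table along these lines (the `Mode` and `FuncRepository` shapes here are sketched assumptions, not the exact transformers definitions):

```python
from dataclasses import dataclass
from enum import Enum

class Mode(Enum):
    INFERENCE = "inference"
    TRAINING = "training"

@dataclass(frozen=True)
class FuncRepository:
    repo_id: str    # kernel repo on the Hub
    func_name: str  # function to pull out of that repo

ROTARY_KERNELS = {
    Mode.TRAINING: FuncRepository(
        repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
    ),
}

entry = ROTARY_KERNELS[Mode.TRAINING]
print(entry.repo_id, entry.func_name)
```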

Comment on lines +4476 to +4483
def attach_hidden_kernels(module):
for name, fn in getattr(module, "_hidden_kernels", {}).items():
if name not in dict(module.named_children()):
module.register_module(name, fn)

def detach_hidden_kernels(module):
for name in getattr(module, "_hidden_kernels", {}):
delattr(module, name)
Contributor Author


Removed the internal structure and relied on native APIs instead, as suggested.

@vasqu vasqu changed the title [Kernels] Fix kernel function registration 🚨 [Kernels] Fix kernel function registration Apr 17, 2026
@vasqu vasqu requested a review from Cyrilvallez April 17, 2026 19:15
Member

@Cyrilvallez Cyrilvallez left a comment


LGTM, thanks!

Comment thread src/transformers/modeling_utils.py Outdated
Comment on lines +4485 to +4486
self.apply(attach_hidden_kernels)
try:
Member


Would put the apply inside the try as well, but not a big deal at all!
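The suggestion is standard try/finally hygiene: if the attach step can itself raise, it should sit inside the `try` so the detach in `finally` still restores the module tree. A toy illustration with hypothetical helpers:

```python
def attach(state):
    # simulate an attach that fails after partially mutating the module tree
    state.append("rotary_fn")
    raise RuntimeError("attach failed halfway")

def detach(state):
    # cleanup written to tolerate a partial attach
    while "rotary_fn" in state:
        state.remove("rotary_fn")

state = []
try:
    try:
        attach(state)  # inside the try: a failure here still reaches finally
    finally:
        detach(state)  # module tree is restored either way
except RuntimeError:
    pass

print(state)  # []
```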

@vasqu vasqu enabled auto-merge April 20, 2026 12:58
@vasqu vasqu added this pull request to the merge queue Apr 20, 2026
Merged via the queue into huggingface:main with commit 253809c Apr 20, 2026
28 checks passed
@vasqu vasqu deleted the fix-kernels-loading-order branch April 20, 2026 13:29
lvliang-intel pushed a commit to lvliang-intel/transformers that referenced this pull request Apr 21, 2026
* fix attmpt

* proper fix - also works with deepspeed

* rely less on internals and add rotary to training

* move under the try as well
artem-spector pushed a commit to artem-spector/transformers that referenced this pull request Apr 21, 2026
* fix attmpt

* proper fix - also works with deepspeed

* rely less on internals and add rotary to training

* move under the try as well
3 participants