
n-to-1 kernel fusion via KernelConfig #45363

Draft

michaelbenayoun wants to merge 10 commits into huggingface:main from michaelbenayoun:fused_kernels

Conversation

@michaelbenayoun
Member

@michaelbenayoun michaelbenayoun commented Apr 10, 2026

What does this PR do?

This PR adds support for fusing multiple modules into a single kernel — the motivating case being fused RMSNorm+MLP kernels, but the API is generic.

What changed

  • FusedModuleBase, fuse_modules, unfuse_modules, register_fusion_patterns added to hub_kernels.py
  • KernelConfig now accepts tuple keys that trigger fusion before kernelization

Two ways to use it

Option A — inline (no model changes needed)

Embed the glob patterns directly in the KernelConfig key as (name, path) pairs:

KernelConfig({
    (
        ("RMSNorm", "model.layers.*.post_attention_layernorm"),
        ("MLP",     "model.layers.*.mlp"),
    ): "org/repo:RMSNormMLP",
})

Option B — via registry (model declares its patterns)

The model class declares where its fusable modules live:

class MyModel(PreTrainedModel):
    _kernel_fusion_patterns = {
        "RMSNormMLP": ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"],
    }

Or externally without touching the class:

register_fusion_patterns(MyModel, {"RMSNormMLP": [...]})
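
Spelled out with the patterns from the class-level declaration above (same values as in Option B, just written in full):

register_fusion_patterns(MyModel, {
    "RMSNormMLP": [
        "model.layers.*.post_attention_layernorm",
        "model.layers.*.mlp",
    ],
})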

Then the KernelConfig key only needs the module names; the paths are looked up in the registry:

KernelConfig({("RMSNorm", "MLP"): "org/repo:RMSNormMLP"})

Example

import copy
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
from transformers.integrations import unfuse_modules


model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random"
tokenizer = AutoTokenizer.from_pretrained(model_id)

kernel_config = KernelConfig({
    (
        ("RMSNorm", "model.layers.*.post_attention_layernorm"),
        ("MLP",     "model.layers.*.mlp"),
    ): "michaelbenayoun/dummy-rmsnorm-mlp:RMSNormMLP",
})

model = AutoModelForCausalLM.from_pretrained(model_id, use_kernels=True, kernel_config=kernel_config, device_map="cuda")

input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids.to(model.device)

original_model = copy.deepcopy(model)
unfuse_modules(original_model)
original_model.eval()

with torch.no_grad():
    fused_out = model(input_ids).logits
    original_out = original_model(input_ids).logits

print("Max diff fused vs original:", (fused_out - original_out).abs().max().item())

Question

How is this related to #45041, and does it serve the same purpose / needs?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread fused_qwen_example.py
Member Author

To be removed before merging. Temporary file.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen3

@michaelbenayoun michaelbenayoun changed the title from "Fused kernels support" to "n-to-1 kernel fusion via KernelConfig" on Apr 13, 2026

Comment thread hub_kernels.py

_FUSION_PATTERNS_REGISTRY[model_class_or_instance] = patterns


class FusedModuleBase(nn.Module):
Member Author

@ArthurZucker the API is way simpler than #44979 because the whole purpose of fusion here is to replace the forward with a kernel, so we do not need all the complex machinery. All we need is a module whose forward can be replaced with the kernel forward.
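
To make that concrete, here is a minimal sketch of the idea. The class name, kernel signature, and gate/up/down projection layout are illustrative assumptions, not the PR's actual FusedModuleBase API:

import torch.nn as nn

class FusedRMSNormMLP(nn.Module):
    # Hypothetical fused module: it keeps references to the original
    # submodules (so their weights stay in the state dict) and routes
    # forward() through a single kernel call.
    def __init__(self, rmsnorm, mlp, kernel_fn):
        super().__init__()
        self.rmsnorm = rmsnorm
        self.mlp = mlp
        self.kernel_fn = kernel_fn  # assumed: loaded from a Hub kernel repo

    def forward(self, hidden_states):
        # One kernel launch replaces mlp(rmsnorm(hidden_states)).
        return self.kernel_fn(
            hidden_states,
            self.rmsnorm.weight,
            self.mlp.gate_proj.weight,
            self.mlp.up_proj.weight,
            self.mlp.down_proj.weight,
        )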

Collaborator

@ArthurZucker ArthurZucker left a comment

much simpler, much better! I like it 🤗 (not an in depth review as we have to discuss some stuff internally!)

Comment on lines +665 to +667
module.add_module(child_names[0], fused_instance)
for child_name in child_names[1:]:
    module.add_module(child_name, nn.Identity())
Collaborator

we're probably gonna have concerns with hooks, especially with TP, but potentially also with accelerate!
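
For context on the lines quoted above: the fused instance takes the first child's slot and every remaining child becomes an nn.Identity() pass-through, so the parent module's unchanged forward still composes correctly. A standalone illustration with simplified stand-ins (not the PR code):

import torch
from torch import nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.post_attention_layernorm = nn.LayerNorm(8)
        self.mlp = nn.Linear(8, 8)

    def forward(self, x):
        # Unchanged parent forward: after fusion, the first call does all
        # the work and the second is an nn.Identity() no-op.
        return self.mlp(self.post_attention_layernorm(x))

block = Block()
fused = nn.Sequential(nn.LayerNorm(8), nn.Linear(8, 8))  # stand-in for the fused module
block.add_module("post_attention_layernorm", fused)
block.add_module("mlp", nn.Identity())
out = block(torch.randn(2, 8))  # same call path as before fusion

Since hooks (and accelerate/TP wrappers) attach to specific module instances, swapping instances like this would drop them, which may be the concern raised here.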

