From b3871906e6cc244bcd7a7a183c311db36ec21c14 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 10 Apr 2026 14:52:10 -0400 Subject: [PATCH 1/8] feat: module fusion API for kernels --- src/transformers/module_fusion.py | 168 ++++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 src/transformers/module_fusion.py diff --git a/src/transformers/module_fusion.py b/src/transformers/module_fusion.py new file mode 100644 index 000000000000..d50576155a8d --- /dev/null +++ b/src/transformers/module_fusion.py @@ -0,0 +1,168 @@ +# Copyright 2026 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import re + +from .utils import is_torch_available + + +if is_torch_available(): + import torch.nn as nn + + +class FusedModuleBase(nn.Module): + def __init__(self, modules_to_fuse: list[nn.Module], source_names: list[str]): + super().__init__() + if len(modules_to_fuse) == 0: + raise ValueError("At least one module must be provided for fusion.") + if len(modules_to_fuse) != len(source_names): + raise ValueError("Length of modules_to_fuse and source_names must match.") + + self._source_names = source_names + + for module in modules_to_fuse: + attr_name = getattr(module, "kernel_layer_name", None) + if attr_name is None: + raise ValueError(f"Module {module} does not have a 'kernel_layer_name' attribute.") + self.add_module(attr_name, module) + + self._fused_module_names = [m.kernel_layer_name for m in modules_to_fuse] + + # `kernelize` validates the kernel's forward signature against the class being replaced. + # Since the fused container sits at the position of the first module in the chain, the + # kernel's forward must match that module's signature. We patch the class-level forward + # here (via `functools.wraps`) so the signature is correct when `kernelize` inspects it. + # The body raises because this forward is always replaced by the kernel before any call. + @functools.wraps(type(modules_to_fuse[0]).forward) + def forward(self, *args, **kwargs): + raise NotImplementedError("FusedModule is a placeholder and should not be called directly.") + + self.__class__.forward = forward + + def __repr__(self): + names = ", ".join(self._fused_module_names) + return f"{self.__class__.__name__}(fused={names})" + + +@functools.cache +def make_fused_module_class(source_layer_names: tuple[str, ...], kernel_layer_name: str) -> type: + """ + Dynamically create and cache a `FusedModuleBase` subclass for a given fusion combination. + + Args: + source_layer_names (`tuple[str, ...]`): + Ordered tuple of `kernel_layer_name` values of the modules being fused + (e.g. ``("RMSNorm", "MLP")``). Used as the cache key — the same combination + always returns the same class object. + kernel_layer_name (`str`): + The name assigned to the fused class, used by `kernelize` to look up the + kernel in the mapping (e.g. ``"RMSNormMLP"``). + + Returns: + A subclass of `FusedModuleBase` with `kernel_layer_name` set as a class attribute. 
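+
+    Example::
+
+        # Minimal sketch: `norm` and `mlp` stand for module instances that already carry a
+        # `kernel_layer_name` attribute (e.g. decorated with @use_kernel_forward_from_hub);
+        # the variable names are illustrative, not part of this patch.
+        FusedRMSNormMLP = make_fused_module_class(("RMSNorm", "MLP"), "RMSNormMLP")
+        fused = FusedRMSNormMLP([norm, mlp], ["post_attention_layernorm", "mlp"])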
+ """ + return type( + f"Fused_{'_'.join(source_layer_names)}", + (FusedModuleBase,), + {"kernel_layer_name": kernel_layer_name}, + ) + + +def fuse_modules( + model: nn.Module, + module_names_to_fuse: list[str], + kernel_layer_name: str, +) -> None: + """ + Fuse a sequence of submodules into a single `FusedModuleBase` subclass in-place. + + For every parent module whose immediate children match all entries in + ``module_names_to_fuse``, the function: + + - replaces the first module with a `FusedModuleBase` subclass instance that holds + all source modules as named children, + - replaces the remaining modules with `nn.Identity()` pass-throughs. + + The fused container's ``forward`` signature is patched to match the first source + module's ``forward``, satisfying the ``kernelize`` signature check. + + Args: + model (`nn.Module`): + The model to modify in-place. + module_names_to_fuse (`list[str]`): + Glob-style paths of the modules to fuse, e.g. + ``["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]``. + Integer indices are replaced with ``*`` so the same pattern applies to + every repeated block. + kernel_layer_name (`str`): + The ``kernel_layer_name`` assigned to the fused class, used by ``kernelize`` + to look up the kernel in the mapping (e.g. ``"RMSNormMLP"``). + + Example:: + + fuse_modules( + model, + ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"], + "RMSNormMLP", + ) + """ + pattern = re.compile(r"\d+") + for module_name, module in model.named_modules(): + generic_children = { + re.sub(pattern, "*", f"{module_name}.{n}" if module_name else n): (n, child) + for n, child in module.named_children() + } + if not all(p in generic_children for p in module_names_to_fuse): + continue + + child_names = [generic_children[p][0] for p in module_names_to_fuse] + modules_to_fuse = [generic_children[p][1] for p in module_names_to_fuse] + + source_layer_names = tuple(getattr(m, "kernel_layer_name") for m in modules_to_fuse) + FusedClass = make_fused_module_class(source_layer_names, kernel_layer_name) + fused_instance = FusedClass(modules_to_fuse, child_names) + + module.add_module(child_names[0], fused_instance) + for child_name in child_names[1:]: + module.add_module(child_name, nn.Identity()) + + +def unfuse_modules(model: nn.Module) -> None: + """ + Revert a previous `fuse_modules` call in-place, restoring the original modules. + + For each `FusedModuleBase` instance found in the model tree, the function: + + - restores the original first module at the fused container's position, + - restores the remaining original modules at their original positions + (replacing the `nn.Identity()` pass-throughs). + + Args: + model (`nn.Module`): The model to restore in-place. + + Example:: + + fuse_modules(model, ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"], "RMSNormMLP") + # ... kernelized forward pass ... 
+ unfuse_modules(model) # back to original + """ + for parent in model.modules(): + for name, child in list(parent.named_children()): + if not isinstance(child, FusedModuleBase): + continue + orig_modules = [getattr(child, layer_name) for layer_name in child._fused_module_names] + parent.add_module(name, orig_modules[0]) + for sibling_name, orig_module in zip(child._source_names[1:], orig_modules[1:]): + parent.add_module(sibling_name, orig_module) From 6bc9402140c508bf15cf7958b6874744494bb1df Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 10 Apr 2026 15:11:47 -0400 Subject: [PATCH 2/8] fix: improve __repr__ for fused modules --- src/transformers/module_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/module_fusion.py b/src/transformers/module_fusion.py index d50576155a8d..a370b496a94c 100644 --- a/src/transformers/module_fusion.py +++ b/src/transformers/module_fusion.py @@ -53,7 +53,7 @@ def forward(self, *args, **kwargs): def __repr__(self): names = ", ".join(self._fused_module_names) - return f"{self.__class__.__name__}(fused={names})" + return f"{self.__class__.__name__}(fused=({names}))" @functools.cache From 62d4454588f15f2c8b45d4059d0f50a3c1740d4d Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 10 Apr 2026 16:15:05 -0400 Subject: [PATCH 3/8] wip: integration to KernelConfig --- src/transformers/modeling_utils.py | 4 ++ .../models/qwen3/modeling_qwen3.py | 4 ++ src/transformers/utils/kernel_config.py | 39 +++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 27fcc3eaae1b..10d6d86f5c8d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3662,6 +3662,10 @@ def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None register_kernel_mapping_transformers() if kernel_config is not None and isinstance(kernel_config, KernelConfig): + # For n-to-1 entries (tuple keys), fuse the corresponding modules in the model + # before validation so that FusedModuleBase instances are visible to sanitize. 
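+            # For example, an entry keyed by ("RMSNorm", "MLP") pointing at a repo such as
+            # "some-org/rmsnorm-mlp:RMSNormMLP" (illustrative id) has its modules fused here and
+            # is re-keyed under "RMSNormMLP" before sanitize_kernel_mapping runs.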
+ kernel_config.apply_fusions(self) + # This will make sure the mapping is valid, and the layers are registered in the model kernel_config.sanitize_kernel_mapping(self) diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 91715a33cf9d..63ecfab0eff1 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -67,6 +67,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +@use_kernel_forward_from_hub("MLP") class Qwen3MLP(nn.Module): def __init__(self, config): super().__init__() @@ -443,6 +444,9 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_gather_output"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + _kernel_fusion_patterns = { + "RMSNormMLP": ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"], + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/utils/kernel_config.py b/src/transformers/utils/kernel_config.py index bb4f965ddbf4..fb45e085a431 100644 --- a/src/transformers/utils/kernel_config.py +++ b/src/transformers/utils/kernel_config.py @@ -116,6 +116,45 @@ def update_kernel(self, repo_id, registered_name, layer_name, device, mode, revi } } + def apply_fusions(self, model): + """ + For each n-to-1 entry (tuple key) in the kernel mapping, find the fusion patterns + registered on the model, fuse the corresponding modules in-place, then replace the + tuple key with the resolved kernel layer name so the rest of the pipeline is unchanged. + """ + from ..module_fusion import fuse_modules + + new_mapping = {} + for layer_name, kernel in self.kernel_mapping.items(): + if not isinstance(layer_name, tuple): + new_mapping[layer_name] = kernel + continue + + # Parse the target kernel layer name from the repo string (the part after ':') + repo_str = kernel if isinstance(kernel, str) else next(iter(kernel.values())) + if isinstance(repo_str, dict): + repo_str = next(iter(repo_str.values())) + kernel_layer_name = repo_str.split(":")[1] if ":" in repo_str else repo_str.split("/")[-1] + + # Only fuse if a kernel is available for the current device + current_device = infer_device(model) + has_kernel_for_device = isinstance(kernel, str) or current_device in kernel + if not has_kernel_for_device: + continue + + # Look up fusion patterns registered on the model class + fusion_patterns = getattr(model, "_kernel_fusion_patterns", {}) + if kernel_layer_name not in fusion_patterns: + raise ValueError( + f"{type(model).__name__} does not define fusion patterns for '{kernel_layer_name}'. " + f'Add `_kernel_fusion_patterns = {{"{kernel_layer_name}": [...]}}` to the model class.' 
+ ) + + fuse_modules(model, fusion_patterns[kernel_layer_name], kernel_layer_name) + new_mapping[kernel_layer_name] = kernel + + self.kernel_mapping = new_mapping + def store_registered_layer_names(self, model): for name, module in model.named_modules(): if hasattr(module, "kernel_layer_name"): From 4082fe15e78f36160d311cc4baa8bc804f274568 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 10 Apr 2026 16:19:18 -0400 Subject: [PATCH 4/8] wip: add temporary example --- fused_qwen_example.py | 65 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 fused_qwen_example.py diff --git a/fused_qwen_example.py b/fused_qwen_example.py new file mode 100644 index 000000000000..58467cbe3f62 --- /dev/null +++ b/fused_qwen_example.py @@ -0,0 +1,65 @@ +import copy +import torch +import torch.nn as nn + +from kernels import Mode, register_kernel_mapping + +from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig +from transformers.module_fusion import unfuse_modules + + +model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" +tokenizer = AutoTokenizer.from_pretrained(model_id) + + +class FakeRMSNormMLP(nn.Module): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + print("Using fake RMSNormMLP kernel") + hidden_states = self.RMSNorm(hidden_states) + hidden_states = self.MLP(hidden_states) + return hidden_states + + +class _InMemoryRepo: + """Minimal fake repository that returns a local class instead of downloading from the hub.""" + + def __init__(self, layer_cls: type): + self.layer_name = layer_cls.__name__ + self._layer_cls = layer_cls + + def load(self) -> type: + return self._layer_cls + + def __hash__(self): + return hash(self._layer_cls) + + def __eq__(self, other): + return isinstance(other, _InMemoryRepo) and self._layer_cls is other._layer_cls + + +# In production this would be a real hub repo string e.g. "kernels-community/rmsnorm-mlp:RMSNormMLP". +# For testing we pre-register a fake in-memory kernel so no hub download is needed. 
+register_kernel_mapping({ + "RMSNormMLP": { + "cuda": {Mode.INFERENCE: _InMemoryRepo(FakeRMSNormMLP)}, + } +}) + +kernel_config = KernelConfig({ + ("RMSNorm", "MLP"): "fake/repo:RMSNormMLP", +}) + +model = AutoModelForCausalLM.from_pretrained(model_id, use_kernels=True, kernel_config=kernel_config, device_map="cuda") +model.eval() + +input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids.to(model.device) + +original_model = copy.deepcopy(model) +unfuse_modules(original_model) +original_model.eval() + +with torch.no_grad(): + fused_out = model(input_ids).logits + original_out = original_model(input_ids).logits + +print("Max diff fused vs original:", (fused_out - original_out).abs().max().item()) From ac4a6994a35cc054c557909546434110014de0dc Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 13 Apr 2026 15:15:56 -0400 Subject: [PATCH 5/8] wip: pattern matching in KernelConfig and actual kernel repo --- fused_qwen_example.py | 43 +------- .../models/qwen3/modeling_qwen3.py | 4 - src/transformers/module_fusion.py | 101 ++++++++++++++++-- src/transformers/utils/kernel_config.py | 30 ++++-- 4 files changed, 116 insertions(+), 62 deletions(-) diff --git a/fused_qwen_example.py b/fused_qwen_example.py index 58467cbe3f62..33a00915811a 100644 --- a/fused_qwen_example.py +++ b/fused_qwen_example.py @@ -1,8 +1,5 @@ import copy import torch -import torch.nn as nn - -from kernels import Mode, register_kernel_mapping from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig from transformers.module_fusion import unfuse_modules @@ -11,46 +8,14 @@ model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" tokenizer = AutoTokenizer.from_pretrained(model_id) - -class FakeRMSNormMLP(nn.Module): - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - print("Using fake RMSNormMLP kernel") - hidden_states = self.RMSNorm(hidden_states) - hidden_states = self.MLP(hidden_states) - return hidden_states - - -class _InMemoryRepo: - """Minimal fake repository that returns a local class instead of downloading from the hub.""" - - def __init__(self, layer_cls: type): - self.layer_name = layer_cls.__name__ - self._layer_cls = layer_cls - - def load(self) -> type: - return self._layer_cls - - def __hash__(self): - return hash(self._layer_cls) - - def __eq__(self, other): - return isinstance(other, _InMemoryRepo) and self._layer_cls is other._layer_cls - - -# In production this would be a real hub repo string e.g. "kernels-community/rmsnorm-mlp:RMSNormMLP". -# For testing we pre-register a fake in-memory kernel so no hub download is needed. 
-register_kernel_mapping({ - "RMSNormMLP": { - "cuda": {Mode.INFERENCE: _InMemoryRepo(FakeRMSNormMLP)}, - } -}) - kernel_config = KernelConfig({ - ("RMSNorm", "MLP"): "fake/repo:RMSNormMLP", + ( + ("RMSNorm", "model.layers.*.post_attention_layernorm"), + ("MLP", "model.layers.*.mlp"), + ): "michaelbenayoun/dummy-rmsnorm-mlp:RMSNormMLP", }) model = AutoModelForCausalLM.from_pretrained(model_id, use_kernels=True, kernel_config=kernel_config, device_map="cuda") -model.eval() input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids.to(model.device) diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 63ecfab0eff1..91715a33cf9d 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -67,7 +67,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -@use_kernel_forward_from_hub("MLP") class Qwen3MLP(nn.Module): def __init__(self, config): super().__init__() @@ -444,9 +443,6 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_gather_output"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - _kernel_fusion_patterns = { - "RMSNormMLP": ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"], - } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/module_fusion.py b/src/transformers/module_fusion.py index a370b496a94c..c0bcf76ee427 100644 --- a/src/transformers/module_fusion.py +++ b/src/transformers/module_fusion.py @@ -22,8 +22,72 @@ import torch.nn as nn +# Module-level registry: model class → {kernel_layer_name → [glob patterns]} +# Populated via `register_fusion_patterns` for models that cannot be modified directly. +_FUSION_PATTERNS_REGISTRY: dict[type, dict[str, list[str]]] = {} + + +def register_fusion_patterns( + model_class_or_instance, + patterns: dict[str, list[str]], +) -> None: + """ + Register kernel fusion patterns for a model class without modifying it directly. + + This is an alternative to setting ``_kernel_fusion_patterns`` as a class attribute, + useful when the model class is frozen or comes from an external library. + + Args: + model_class_or_instance: + The model class (or an instance of it) for which patterns are being registered. + patterns (`dict[str, list[str]]`): + Mapping from ``kernel_layer_name`` to a list of glob-style module paths, + identical in format to ``_kernel_fusion_patterns``. For example:: + + { + "RMSNormMLP": [ + "model.layers.*.post_attention_layernorm", + "model.layers.*.mlp", + ] + } + + Example:: + + from transformers.module_fusion import register_fusion_patterns + from transformers.models.qwen3 import Qwen3ForCausalLM + + register_fusion_patterns( + Qwen3ForCausalLM, + { + "RMSNormMLP": [ + "model.layers.*.post_attention_layernorm", + "model.layers.*.mlp", + ] + }, + ) + """ + if not isinstance(model_class_or_instance, type): + model_class_or_instance = type(model_class_or_instance) + _FUSION_PATTERNS_REGISTRY[model_class_or_instance] = patterns + + class FusedModuleBase(nn.Module): - def __init__(self, modules_to_fuse: list[nn.Module], source_names: list[str]): + def __init__( + self, + modules_to_fuse: list[nn.Module], + source_names: list[str], + fused_module_names: list[str] | None = None, + ): + """ + Args: + modules_to_fuse: The source modules to fuse together. 
+ source_names: The attribute names under which each module lives in its parent + (used to restore them on ``unfuse_modules``). + fused_module_names: The names under which each source module is registered as a + child of this container (i.e. ``self.``). When ``None``, the + ``kernel_layer_name`` attribute of each source module is used. Pass this + explicitly when the source modules do not carry ``@use_kernel_forward_from_hub``. + """ super().__init__() if len(modules_to_fuse) == 0: raise ValueError("At least one module must be provided for fusion.") @@ -32,13 +96,24 @@ def __init__(self, modules_to_fuse: list[nn.Module], source_names: list[str]): self._source_names = source_names - for module in modules_to_fuse: - attr_name = getattr(module, "kernel_layer_name", None) - if attr_name is None: - raise ValueError(f"Module {module} does not have a 'kernel_layer_name' attribute.") - self.add_module(attr_name, module) - - self._fused_module_names = [m.kernel_layer_name for m in modules_to_fuse] + if fused_module_names is not None: + if len(fused_module_names) != len(modules_to_fuse): + raise ValueError("Length of fused_module_names and modules_to_fuse must match.") + for module, name in zip(modules_to_fuse, fused_module_names): + self.add_module(name, module) + self._fused_module_names = list(fused_module_names) + else: + for module in modules_to_fuse: + attr_name = getattr(module, "kernel_layer_name", None) + if attr_name is None: + raise ValueError( + f"Module {module} does not have a 'kernel_layer_name' attribute. " + f"Either decorate it with @use_kernel_forward_from_hub or provide " + f"explicit names via the inline pattern format: " + f'(("", ""), ...).' + ) + self.add_module(attr_name, module) + self._fused_module_names = [m.kernel_layer_name for m in modules_to_fuse] # `kernelize` validates the kernel's forward signature against the class being replaced. # Since the fused container sits at the position of the first module in the chain, the @@ -84,6 +159,7 @@ def fuse_modules( model: nn.Module, module_names_to_fuse: list[str], kernel_layer_name: str, + source_layer_names: list[str] | None = None, ) -> None: """ Fuse a sequence of submodules into a single `FusedModuleBase` subclass in-place. 
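The optional `source_layer_names` argument added above lets fusion target modules that were never decorated with `@use_kernel_forward_from_hub`. A minimal sketch of a direct call, assuming a model exposing the Qwen3-style layer paths used in the example script (kernel resolution still goes through `KernelConfig`):

    from transformers.module_fusion import fuse_modules, unfuse_modules

    fuse_modules(
        model,
        ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"],
        "RMSNormMLP",
        source_layer_names=["RMSNorm", "MLP"],  # exposed as self.RMSNorm / self.MLP on the fused container
    )
    # ... kernelized forward pass ...
    unfuse_modules(model)  # restore the original modules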
@@ -130,9 +206,12 @@ def fuse_modules( child_names = [generic_children[p][0] for p in module_names_to_fuse] modules_to_fuse = [generic_children[p][1] for p in module_names_to_fuse] - source_layer_names = tuple(getattr(m, "kernel_layer_name") for m in modules_to_fuse) - FusedClass = make_fused_module_class(source_layer_names, kernel_layer_name) - fused_instance = FusedClass(modules_to_fuse, child_names) + if source_layer_names is not None: + resolved_names = tuple(source_layer_names) + else: + resolved_names = tuple(getattr(m, "kernel_layer_name") for m in modules_to_fuse) + FusedClass = make_fused_module_class(resolved_names, kernel_layer_name) + fused_instance = FusedClass(modules_to_fuse, child_names, fused_module_names=list(resolved_names)) module.add_module(child_names[0], fused_instance) for child_name in child_names[1:]: diff --git a/src/transformers/utils/kernel_config.py b/src/transformers/utils/kernel_config.py index fb45e085a431..8f66411480c0 100644 --- a/src/transformers/utils/kernel_config.py +++ b/src/transformers/utils/kernel_config.py @@ -142,15 +142,29 @@ def apply_fusions(self, model): if not has_kernel_for_device: continue - # Look up fusion patterns registered on the model class - fusion_patterns = getattr(model, "_kernel_fusion_patterns", {}) - if kernel_layer_name not in fusion_patterns: - raise ValueError( - f"{type(model).__name__} does not define fusion patterns for '{kernel_layer_name}'. " - f'Add `_kernel_fusion_patterns = {{"{kernel_layer_name}": [...]}}` to the model class.' - ) + # Detect inline format: tuple of (kernel_layer_name, glob_pattern) pairs. + # e.g. (("RMSNorm", "model.layers.*.post_attention_layernorm"), ("MLP", "model.layers.*.mlp")) + is_inline = all(isinstance(item, tuple) and len(item) == 2 for item in layer_name) + if is_inline: + source_names = [item[0] for item in layer_name] + patterns = [item[1] for item in layer_name] + fuse_modules(model, patterns, kernel_layer_name, source_layer_names=source_names) + else: + # Legacy format: ("RMSNorm", "MLP") — look up patterns from model class or registry. + from ..module_fusion import _FUSION_PATTERNS_REGISTRY - fuse_modules(model, fusion_patterns[kernel_layer_name], kernel_layer_name) + fusion_patterns = getattr(model, "_kernel_fusion_patterns", None) or _FUSION_PATTERNS_REGISTRY.get( + type(model), {} + ) + if kernel_layer_name not in fusion_patterns: + raise ValueError( + f"{type(model).__name__} does not define fusion patterns for '{kernel_layer_name}'. " + f'Either add `_kernel_fusion_patterns = {{"{kernel_layer_name}": [...]}}` to the model class, ' + f"call `register_fusion_patterns({type(model).__name__}, ...)` before loading the model, " + f"or use the inline pattern format: " + f'(("{kernel_layer_name}", ""), ...).' 
+ ) + fuse_modules(model, fusion_patterns[kernel_layer_name], kernel_layer_name) new_mapping[kernel_layer_name] = kernel self.kernel_mapping = new_mapping From e13111fbf358335fff766842d8374a96e9fb7ca9 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 13 Apr 2026 15:21:59 -0400 Subject: [PATCH 6/8] refactor: move relevant code to hub_kernels.py --- fused_qwen_example.py | 2 +- src/transformers/integrations/__init__.py | 10 + src/transformers/integrations/hub_kernels.py | 253 ++++++++++++++++++- src/transformers/module_fusion.py | 247 ------------------ src/transformers/utils/kernel_config.py | 5 +- 5 files changed, 260 insertions(+), 257 deletions(-) delete mode 100644 src/transformers/module_fusion.py diff --git a/fused_qwen_example.py b/fused_qwen_example.py index 33a00915811a..0b7abd98439c 100644 --- a/fused_qwen_example.py +++ b/fused_qwen_example.py @@ -2,7 +2,7 @@ import torch from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig -from transformers.module_fusion import unfuse_modules +from transformers.integrations import unfuse_modules model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 336db3773f76..f74836d890f7 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -67,10 +67,15 @@ ], "hqq": ["prepare_for_hqq_linear"], "hub_kernels": [ + "FusedModuleBase", "LayerRepository", + "fuse_modules", "lazy_load_kernel", + "make_fused_module_class", + "register_fusion_patterns", "register_kernel_mapping", "replace_kernel_forward_from_hub", + "unfuse_modules", "use_kernel_forward_from_hub", "use_kernel_func_from_hub", "use_kernelized_func", @@ -221,10 +226,15 @@ from .higgs import HiggsLinear, dequantize_higgs, quantize_with_higgs, replace_with_higgs_linear from .hqq import prepare_for_hqq_linear from .hub_kernels import ( + FusedModuleBase, LayerRepository, + fuse_modules, lazy_load_kernel, + make_fused_module_class, + register_fusion_patterns, register_kernel_mapping, replace_kernel_forward_from_hub, + unfuse_modules, use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func, diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 39c1448af02b..fb5f2064917b 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools import importlib.metadata import os import re @@ -21,10 +22,14 @@ from packaging import version as pkg_version from ..utils import ENV_VARS_TRUE_VALUES, logging -from ..utils.import_utils import is_kernels_available +from ..utils.import_utils import is_kernels_available, is_torch_available from .flash_attention import flash_attention_forward +if is_torch_available(): + import torch.nn as nn + + logger = logging.get_logger(__name__) try: @@ -483,14 +488,252 @@ def allow_all_hub_kernels(): ALLOW_ALL_KERNELS = False +# --------------------------------------------------------------------------- +# Module fusion helpers +# --------------------------------------------------------------------------- + +# Model class → {kernel_layer_name → [glob patterns]} +# Populated via `register_fusion_patterns` for models that cannot be modified directly. 
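+# e.g. {Qwen3ForCausalLM: {"RMSNormMLP": ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]}}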
+_FUSION_PATTERNS_REGISTRY: dict[type, dict[str, list[str]]] = {} + + +def register_fusion_patterns( + model_class_or_instance, + patterns: dict[str, list[str]], +) -> None: + """ + Register kernel fusion patterns for a model class without modifying it directly. + + This is an alternative to setting ``_kernel_fusion_patterns`` as a class attribute, + useful when the model class is frozen or comes from an external library. + + Args: + model_class_or_instance: + The model class (or an instance of it) for which patterns are being registered. + patterns (`dict[str, list[str]]`): + Mapping from ``kernel_layer_name`` to a list of glob-style module paths, + identical in format to ``_kernel_fusion_patterns``. For example:: + + { + "RMSNormMLP": [ + "model.layers.*.post_attention_layernorm", + "model.layers.*.mlp", + ] + } + + Example:: + + from transformers.integrations import register_fusion_patterns + from transformers.models.qwen3 import Qwen3ForCausalLM + + register_fusion_patterns( + Qwen3ForCausalLM, + { + "RMSNormMLP": [ + "model.layers.*.post_attention_layernorm", + "model.layers.*.mlp", + ] + }, + ) + """ + if not isinstance(model_class_or_instance, type): + model_class_or_instance = type(model_class_or_instance) + _FUSION_PATTERNS_REGISTRY[model_class_or_instance] = patterns + + +class FusedModuleBase(nn.Module): + def __init__( + self, + modules_to_fuse: list["nn.Module"], + source_names: list[str], + fused_module_names: list[str] | None = None, + ): + """ + Args: + modules_to_fuse: The source modules to fuse together. + source_names: The attribute names under which each module lives in its parent + (used to restore them on ``unfuse_modules``). + fused_module_names: The names under which each source module is registered as a + child of this container (i.e. ``self.``). When ``None``, the + ``kernel_layer_name`` attribute of each source module is used. Pass this + explicitly when the source modules do not carry ``@use_kernel_forward_from_hub``. + """ + super().__init__() + if len(modules_to_fuse) == 0: + raise ValueError("At least one module must be provided for fusion.") + if len(modules_to_fuse) != len(source_names): + raise ValueError("Length of modules_to_fuse and source_names must match.") + + self._source_names = source_names + + if fused_module_names is not None: + if len(fused_module_names) != len(modules_to_fuse): + raise ValueError("Length of fused_module_names and modules_to_fuse must match.") + for module, name in zip(modules_to_fuse, fused_module_names): + self.add_module(name, module) + self._fused_module_names = list(fused_module_names) + else: + for module in modules_to_fuse: + attr_name = getattr(module, "kernel_layer_name", None) + if attr_name is None: + raise ValueError( + f"Module {module} does not have a 'kernel_layer_name' attribute. " + f"Either decorate it with @use_kernel_forward_from_hub or provide " + f"explicit names via the inline pattern format: " + f'(("", ""), ...).' + ) + self.add_module(attr_name, module) + self._fused_module_names = [m.kernel_layer_name for m in modules_to_fuse] + + # `kernelize` validates the kernel's forward signature against the class being replaced. + # Since the fused container sits at the position of the first module in the chain, the + # kernel's forward must match that module's signature. We patch the class-level forward + # here (via `functools.wraps`) so the signature is correct when `kernelize` inspects it. + # The body raises because this forward is always replaced by the kernel before any call. 
+ @functools.wraps(type(modules_to_fuse[0]).forward) + def forward(self, *args, **kwargs): + raise NotImplementedError("FusedModule is a placeholder and should not be called directly.") + + self.__class__.forward = forward + + def __repr__(self): + names = ", ".join(self._fused_module_names) + return f"{self.__class__.__name__}(fused=({names}))" + + +@functools.cache +def make_fused_module_class(source_layer_names: tuple[str, ...], kernel_layer_name: str) -> type: + """ + Dynamically create and cache a `FusedModuleBase` subclass for a given fusion combination. + + Args: + source_layer_names (`tuple[str, ...]`): + Ordered tuple of `kernel_layer_name` values of the modules being fused + (e.g. ``("RMSNorm", "MLP")``). Used as the cache key — the same combination + always returns the same class object. + kernel_layer_name (`str`): + The name assigned to the fused class, used by `kernelize` to look up the + kernel in the mapping (e.g. ``"RMSNormMLP"``). + + Returns: + A subclass of `FusedModuleBase` with `kernel_layer_name` set as a class attribute. + """ + return type( + f"Fused_{'_'.join(source_layer_names)}", + (FusedModuleBase,), + {"kernel_layer_name": kernel_layer_name}, + ) + + +def fuse_modules( + model: "nn.Module", + module_names_to_fuse: list[str], + kernel_layer_name: str, + source_layer_names: list[str] | None = None, +) -> None: + """ + Fuse a sequence of submodules into a single `FusedModuleBase` subclass in-place. + + For every parent module whose immediate children match all entries in + ``module_names_to_fuse``, the function: + + - replaces the first module with a `FusedModuleBase` subclass instance that holds + all source modules as named children, + - replaces the remaining modules with `nn.Identity()` pass-throughs. + + The fused container's ``forward`` signature is patched to match the first source + module's ``forward``, satisfying the ``kernelize`` signature check. + + Args: + model (`nn.Module`): + The model to modify in-place. + module_names_to_fuse (`list[str]`): + Glob-style paths of the modules to fuse, e.g. + ``["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]``. + Integer indices are replaced with ``*`` so the same pattern applies to + every repeated block. + kernel_layer_name (`str`): + The ``kernel_layer_name`` assigned to the fused class, used by ``kernelize`` + to look up the kernel in the mapping (e.g. ``"RMSNormMLP"``). + source_layer_names (`list[str]`, *optional*): + Explicit names for the child modules inside the fused container + (e.g. ``["RMSNorm", "MLP"]``). When ``None``, the ``kernel_layer_name`` + attribute of each source module is used. 
+ + Example:: + + fuse_modules( + model, + ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"], + "RMSNormMLP", + ) + """ + pattern = re.compile(r"\d+") + for module_name, module in model.named_modules(): + generic_children = { + re.sub(pattern, "*", f"{module_name}.{n}" if module_name else n): (n, child) + for n, child in module.named_children() + } + if not all(p in generic_children for p in module_names_to_fuse): + continue + + child_names = [generic_children[p][0] for p in module_names_to_fuse] + modules_to_fuse = [generic_children[p][1] for p in module_names_to_fuse] + + if source_layer_names is not None: + resolved_names = tuple(source_layer_names) + else: + resolved_names = tuple(getattr(m, "kernel_layer_name") for m in modules_to_fuse) + FusedClass = make_fused_module_class(resolved_names, kernel_layer_name) + fused_instance = FusedClass(modules_to_fuse, child_names, fused_module_names=list(resolved_names)) + + module.add_module(child_names[0], fused_instance) + for child_name in child_names[1:]: + module.add_module(child_name, nn.Identity()) + + +def unfuse_modules(model: "nn.Module") -> None: + """ + Revert a previous `fuse_modules` call in-place, restoring the original modules. + + For each `FusedModuleBase` instance found in the model tree, the function: + + - restores the original first module at the fused container's position, + - restores the remaining original modules at their original positions + (replacing the `nn.Identity()` pass-throughs). + + Args: + model (`nn.Module`): The model to restore in-place. + + Example:: + + fuse_modules(model, ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"], "RMSNormMLP") + # ... kernelized forward pass ... + unfuse_modules(model) # back to original + """ + for parent in model.modules(): + for name, child in list(parent.named_children()): + if not isinstance(child, FusedModuleBase): + continue + orig_modules = [getattr(child, layer_name) for layer_name in child._fused_module_names] + parent.add_module(name, orig_modules[0]) + for sibling_name, orig_module in zip(child._source_names[1:], orig_modules[1:]): + parent.add_module(sibling_name, orig_module) + + __all__ = [ + "FusedModuleBase", "LayerRepository", - "use_kernel_forward_from_hub", - "use_kernel_func_from_hub", + "fuse_modules", + "get_kernel", + "lazy_load_kernel", + "make_fused_module_class", + "register_fusion_patterns", "register_kernel_mapping", "register_kernel_mapping_transformers", "replace_kernel_forward_from_hub", - "lazy_load_kernel", - "get_kernel", + "unfuse_modules", + "use_kernel_forward_from_hub", + "use_kernel_func_from_hub", "use_kernelized_func", ] # type: ignore diff --git a/src/transformers/module_fusion.py b/src/transformers/module_fusion.py deleted file mode 100644 index c0bcf76ee427..000000000000 --- a/src/transformers/module_fusion.py +++ /dev/null @@ -1,247 +0,0 @@ -# Copyright 2026 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import functools -import re - -from .utils import is_torch_available - - -if is_torch_available(): - import torch.nn as nn - - -# Module-level registry: model class → {kernel_layer_name → [glob patterns]} -# Populated via `register_fusion_patterns` for models that cannot be modified directly. -_FUSION_PATTERNS_REGISTRY: dict[type, dict[str, list[str]]] = {} - - -def register_fusion_patterns( - model_class_or_instance, - patterns: dict[str, list[str]], -) -> None: - """ - Register kernel fusion patterns for a model class without modifying it directly. - - This is an alternative to setting ``_kernel_fusion_patterns`` as a class attribute, - useful when the model class is frozen or comes from an external library. - - Args: - model_class_or_instance: - The model class (or an instance of it) for which patterns are being registered. - patterns (`dict[str, list[str]]`): - Mapping from ``kernel_layer_name`` to a list of glob-style module paths, - identical in format to ``_kernel_fusion_patterns``. For example:: - - { - "RMSNormMLP": [ - "model.layers.*.post_attention_layernorm", - "model.layers.*.mlp", - ] - } - - Example:: - - from transformers.module_fusion import register_fusion_patterns - from transformers.models.qwen3 import Qwen3ForCausalLM - - register_fusion_patterns( - Qwen3ForCausalLM, - { - "RMSNormMLP": [ - "model.layers.*.post_attention_layernorm", - "model.layers.*.mlp", - ] - }, - ) - """ - if not isinstance(model_class_or_instance, type): - model_class_or_instance = type(model_class_or_instance) - _FUSION_PATTERNS_REGISTRY[model_class_or_instance] = patterns - - -class FusedModuleBase(nn.Module): - def __init__( - self, - modules_to_fuse: list[nn.Module], - source_names: list[str], - fused_module_names: list[str] | None = None, - ): - """ - Args: - modules_to_fuse: The source modules to fuse together. - source_names: The attribute names under which each module lives in its parent - (used to restore them on ``unfuse_modules``). - fused_module_names: The names under which each source module is registered as a - child of this container (i.e. ``self.``). When ``None``, the - ``kernel_layer_name`` attribute of each source module is used. Pass this - explicitly when the source modules do not carry ``@use_kernel_forward_from_hub``. - """ - super().__init__() - if len(modules_to_fuse) == 0: - raise ValueError("At least one module must be provided for fusion.") - if len(modules_to_fuse) != len(source_names): - raise ValueError("Length of modules_to_fuse and source_names must match.") - - self._source_names = source_names - - if fused_module_names is not None: - if len(fused_module_names) != len(modules_to_fuse): - raise ValueError("Length of fused_module_names and modules_to_fuse must match.") - for module, name in zip(modules_to_fuse, fused_module_names): - self.add_module(name, module) - self._fused_module_names = list(fused_module_names) - else: - for module in modules_to_fuse: - attr_name = getattr(module, "kernel_layer_name", None) - if attr_name is None: - raise ValueError( - f"Module {module} does not have a 'kernel_layer_name' attribute. " - f"Either decorate it with @use_kernel_forward_from_hub or provide " - f"explicit names via the inline pattern format: " - f'(("", ""), ...).' - ) - self.add_module(attr_name, module) - self._fused_module_names = [m.kernel_layer_name for m in modules_to_fuse] - - # `kernelize` validates the kernel's forward signature against the class being replaced. 
- # Since the fused container sits at the position of the first module in the chain, the - # kernel's forward must match that module's signature. We patch the class-level forward - # here (via `functools.wraps`) so the signature is correct when `kernelize` inspects it. - # The body raises because this forward is always replaced by the kernel before any call. - @functools.wraps(type(modules_to_fuse[0]).forward) - def forward(self, *args, **kwargs): - raise NotImplementedError("FusedModule is a placeholder and should not be called directly.") - - self.__class__.forward = forward - - def __repr__(self): - names = ", ".join(self._fused_module_names) - return f"{self.__class__.__name__}(fused=({names}))" - - -@functools.cache -def make_fused_module_class(source_layer_names: tuple[str, ...], kernel_layer_name: str) -> type: - """ - Dynamically create and cache a `FusedModuleBase` subclass for a given fusion combination. - - Args: - source_layer_names (`tuple[str, ...]`): - Ordered tuple of `kernel_layer_name` values of the modules being fused - (e.g. ``("RMSNorm", "MLP")``). Used as the cache key — the same combination - always returns the same class object. - kernel_layer_name (`str`): - The name assigned to the fused class, used by `kernelize` to look up the - kernel in the mapping (e.g. ``"RMSNormMLP"``). - - Returns: - A subclass of `FusedModuleBase` with `kernel_layer_name` set as a class attribute. - """ - return type( - f"Fused_{'_'.join(source_layer_names)}", - (FusedModuleBase,), - {"kernel_layer_name": kernel_layer_name}, - ) - - -def fuse_modules( - model: nn.Module, - module_names_to_fuse: list[str], - kernel_layer_name: str, - source_layer_names: list[str] | None = None, -) -> None: - """ - Fuse a sequence of submodules into a single `FusedModuleBase` subclass in-place. - - For every parent module whose immediate children match all entries in - ``module_names_to_fuse``, the function: - - - replaces the first module with a `FusedModuleBase` subclass instance that holds - all source modules as named children, - - replaces the remaining modules with `nn.Identity()` pass-throughs. - - The fused container's ``forward`` signature is patched to match the first source - module's ``forward``, satisfying the ``kernelize`` signature check. - - Args: - model (`nn.Module`): - The model to modify in-place. - module_names_to_fuse (`list[str]`): - Glob-style paths of the modules to fuse, e.g. - ``["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]``. - Integer indices are replaced with ``*`` so the same pattern applies to - every repeated block. - kernel_layer_name (`str`): - The ``kernel_layer_name`` assigned to the fused class, used by ``kernelize`` - to look up the kernel in the mapping (e.g. ``"RMSNormMLP"``). 
- - Example:: - - fuse_modules( - model, - ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"], - "RMSNormMLP", - ) - """ - pattern = re.compile(r"\d+") - for module_name, module in model.named_modules(): - generic_children = { - re.sub(pattern, "*", f"{module_name}.{n}" if module_name else n): (n, child) - for n, child in module.named_children() - } - if not all(p in generic_children for p in module_names_to_fuse): - continue - - child_names = [generic_children[p][0] for p in module_names_to_fuse] - modules_to_fuse = [generic_children[p][1] for p in module_names_to_fuse] - - if source_layer_names is not None: - resolved_names = tuple(source_layer_names) - else: - resolved_names = tuple(getattr(m, "kernel_layer_name") for m in modules_to_fuse) - FusedClass = make_fused_module_class(resolved_names, kernel_layer_name) - fused_instance = FusedClass(modules_to_fuse, child_names, fused_module_names=list(resolved_names)) - - module.add_module(child_names[0], fused_instance) - for child_name in child_names[1:]: - module.add_module(child_name, nn.Identity()) - - -def unfuse_modules(model: nn.Module) -> None: - """ - Revert a previous `fuse_modules` call in-place, restoring the original modules. - - For each `FusedModuleBase` instance found in the model tree, the function: - - - restores the original first module at the fused container's position, - - restores the remaining original modules at their original positions - (replacing the `nn.Identity()` pass-throughs). - - Args: - model (`nn.Module`): The model to restore in-place. - - Example:: - - fuse_modules(model, ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"], "RMSNormMLP") - # ... kernelized forward pass ... - unfuse_modules(model) # back to original - """ - for parent in model.modules(): - for name, child in list(parent.named_children()): - if not isinstance(child, FusedModuleBase): - continue - orig_modules = [getattr(child, layer_name) for layer_name in child._fused_module_names] - parent.add_module(name, orig_modules[0]) - for sibling_name, orig_module in zip(child._source_names[1:], orig_modules[1:]): - parent.add_module(sibling_name, orig_module) diff --git a/src/transformers/utils/kernel_config.py b/src/transformers/utils/kernel_config.py index 8f66411480c0..1c7541fbcbc2 100644 --- a/src/transformers/utils/kernel_config.py +++ b/src/transformers/utils/kernel_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from ..integrations.hub_kernels import _FUSION_PATTERNS_REGISTRY, fuse_modules from ..utils import PushToHubMixin @@ -122,8 +123,6 @@ def apply_fusions(self, model): registered on the model, fuse the corresponding modules in-place, then replace the tuple key with the resolved kernel layer name so the rest of the pipeline is unchanged. """ - from ..module_fusion import fuse_modules - new_mapping = {} for layer_name, kernel in self.kernel_mapping.items(): if not isinstance(layer_name, tuple): @@ -151,8 +150,6 @@ def apply_fusions(self, model): fuse_modules(model, patterns, kernel_layer_name, source_layer_names=source_names) else: # Legacy format: ("RMSNorm", "MLP") — look up patterns from model class or registry. 
- from ..module_fusion import _FUSION_PATTERNS_REGISTRY - fusion_patterns = getattr(model, "_kernel_fusion_patterns", None) or _FUSION_PATTERNS_REGISTRY.get( type(model), {} ) From d9d53f08e40d98f9acfbaa6e2814169b983e1651 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 13 Apr 2026 15:32:00 -0400 Subject: [PATCH 7/8] docs: reformat docstring --- src/transformers/integrations/hub_kernels.py | 74 +++++--------------- 1 file changed, 17 insertions(+), 57 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index fb5f2064917b..d1cc48c26093 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -488,10 +488,6 @@ def allow_all_hub_kernels(): ALLOW_ALL_KERNELS = False -# --------------------------------------------------------------------------- -# Module fusion helpers -# --------------------------------------------------------------------------- - # Model class → {kernel_layer_name → [glob patterns]} # Populated via `register_fusion_patterns` for models that cannot be modified directly. _FUSION_PATTERNS_REGISTRY: dict[type, dict[str, list[str]]] = {} @@ -504,37 +500,15 @@ def register_fusion_patterns( """ Register kernel fusion patterns for a model class without modifying it directly. - This is an alternative to setting ``_kernel_fusion_patterns`` as a class attribute, + This is an alternative to setting `_kernel_fusion_patterns` as a class attribute, useful when the model class is frozen or comes from an external library. Args: model_class_or_instance: The model class (or an instance of it) for which patterns are being registered. patterns (`dict[str, list[str]]`): - Mapping from ``kernel_layer_name`` to a list of glob-style module paths, - identical in format to ``_kernel_fusion_patterns``. For example:: - - { - "RMSNormMLP": [ - "model.layers.*.post_attention_layernorm", - "model.layers.*.mlp", - ] - } - - Example:: - - from transformers.integrations import register_fusion_patterns - from transformers.models.qwen3 import Qwen3ForCausalLM - - register_fusion_patterns( - Qwen3ForCausalLM, - { - "RMSNormMLP": [ - "model.layers.*.post_attention_layernorm", - "model.layers.*.mlp", - ] - }, - ) + Mapping from `kernel_layer_name` to a list of glob-style module paths, + identical in format to `_kernel_fusion_patterns`. """ if not isinstance(model_class_or_instance, type): model_class_or_instance = type(model_class_or_instance) @@ -552,11 +526,11 @@ def __init__( Args: modules_to_fuse: The source modules to fuse together. source_names: The attribute names under which each module lives in its parent - (used to restore them on ``unfuse_modules``). + (used to restore them on `unfuse_modules`). fused_module_names: The names under which each source module is registered as a - child of this container (i.e. ``self.``). When ``None``, the - ``kernel_layer_name`` attribute of each source module is used. Pass this - explicitly when the source modules do not carry ``@use_kernel_forward_from_hub``. + child of this container (i.e. `self.`). When `None`, the + `kernel_layer_name` attribute of each source module is used. Pass this + explicitly when the source modules do not carry `@use_kernel_forward_from_hub`. 
""" super().__init__() if len(modules_to_fuse) == 0: @@ -609,11 +583,11 @@ def make_fused_module_class(source_layer_names: tuple[str, ...], kernel_layer_na Args: source_layer_names (`tuple[str, ...]`): Ordered tuple of `kernel_layer_name` values of the modules being fused - (e.g. ``("RMSNorm", "MLP")``). Used as the cache key — the same combination + (e.g. `("RMSNorm", "MLP")`). Used as the cache key — the same combination always returns the same class object. kernel_layer_name (`str`): The name assigned to the fused class, used by `kernelize` to look up the - kernel in the mapping (e.g. ``"RMSNormMLP"``). + kernel in the mapping (e.g. `"RMSNormMLP"`). Returns: A subclass of `FusedModuleBase` with `kernel_layer_name` set as a class attribute. @@ -635,38 +609,30 @@ def fuse_modules( Fuse a sequence of submodules into a single `FusedModuleBase` subclass in-place. For every parent module whose immediate children match all entries in - ``module_names_to_fuse``, the function: + `module_names_to_fuse`, the function: - replaces the first module with a `FusedModuleBase` subclass instance that holds all source modules as named children, - replaces the remaining modules with `nn.Identity()` pass-throughs. - The fused container's ``forward`` signature is patched to match the first source - module's ``forward``, satisfying the ``kernelize`` signature check. + The fused container's `forward` signature is patched to match the first source + module's `forward`, satisfying the `kernelize` signature check. Args: model (`nn.Module`): The model to modify in-place. module_names_to_fuse (`list[str]`): Glob-style paths of the modules to fuse, e.g. - ``["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]``. - Integer indices are replaced with ``*`` so the same pattern applies to + `["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]`. + Integer indices are replaced with `*` so the same pattern applies to every repeated block. kernel_layer_name (`str`): - The ``kernel_layer_name`` assigned to the fused class, used by ``kernelize`` - to look up the kernel in the mapping (e.g. ``"RMSNormMLP"``). + The `kernel_layer_name` assigned to the fused class, used by `kernelize` + to look up the kernel in the mapping (e.g. `"RMSNormMLP"`). source_layer_names (`list[str]`, *optional*): Explicit names for the child modules inside the fused container - (e.g. ``["RMSNorm", "MLP"]``). When ``None``, the ``kernel_layer_name`` + (e.g. `["RMSNorm", "MLP"]`). When `None`, the `kernel_layer_name` attribute of each source module is used. - - Example:: - - fuse_modules( - model, - ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"], - "RMSNormMLP", - ) """ pattern = re.compile(r"\d+") for module_name, module in model.named_modules(): @@ -704,12 +670,6 @@ def unfuse_modules(model: "nn.Module") -> None: Args: model (`nn.Module`): The model to restore in-place. - - Example:: - - fuse_modules(model, ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"], "RMSNormMLP") - # ... kernelized forward pass ... 
- unfuse_modules(model) # back to original """ for parent in model.modules(): for name, child in list(parent.named_children()): From e1c7f3f0bc4f3f702e2f7af43655f0dcfbab2d0a Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 13 Apr 2026 15:40:12 -0400 Subject: [PATCH 8/8] refactor: remove comment --- src/transformers/utils/kernel_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/utils/kernel_config.py b/src/transformers/utils/kernel_config.py index 1c7541fbcbc2..3a64a3808279 100644 --- a/src/transformers/utils/kernel_config.py +++ b/src/transformers/utils/kernel_config.py @@ -149,7 +149,6 @@ def apply_fusions(self, model): patterns = [item[1] for item in layer_name] fuse_modules(model, patterns, kernel_layer_name, source_layer_names=source_names) else: - # Legacy format: ("RMSNorm", "MLP") — look up patterns from model class or registry. fusion_patterns = getattr(model, "_kernel_fusion_patterns", None) or _FUSION_PATTERNS_REGISTRY.get( type(model), {} )
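For reference, the layer that a hub repo such as `michaelbenayoun/dummy-rmsnorm-mlp:RMSNormMLP` is expected to expose has the same shape as the in-memory fake removed in patch 5: `kernelize` grafts its forward onto the fused container, so the original submodules are reachable under the names taken from the mapping key. A naive, unfused sketch (the real hub kernel is assumed to actually fuse the two ops):

    import torch
    import torch.nn as nn


    class RMSNormMLP(nn.Module):
        # Runs on the Fused_RMSNorm_MLP container, so `self.RMSNorm` and `self.MLP`
        # are the original post_attention_layernorm and mlp submodules.
        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            hidden_states = self.RMSNorm(hidden_states)
            return self.MLP(hidden_states)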