30 changes: 30 additions & 0 deletions fused_qwen_example.py
Member Author

To be removed before merging. Temporary file.

@@ -0,0 +1,30 @@
import copy
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
from transformers.integrations import unfuse_modules


model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random"
tokenizer = AutoTokenizer.from_pretrained(model_id)

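# n-to-1 mapping: the tuple key pairs each source module's kernel layer name
# with its glob path; the value points at the fused kernel on the Hub.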
kernel_config = KernelConfig({
(
("RMSNorm", "model.layers.*.post_attention_layernorm"),
("MLP", "model.layers.*.mlp"),
): "michaelbenayoun/dummy-rmsnorm-mlp:RMSNormMLP",
})

model = AutoModelForCausalLM.from_pretrained(model_id, use_kernels=True, kernel_config=kernel_config, device_map="cuda")

input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids.to(model.device)

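# Deep-copy the fused model and revert the fusion to obtain an unfused reference.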
original_model = copy.deepcopy(model)
unfuse_modules(original_model)
original_model.eval()
model.eval()  # keep both models in eval mode for a fair comparison

with torch.no_grad():
fused_out = model(input_ids).logits
original_out = original_model(input_ids).logits

print("Max diff fused vs original:", (fused_out - original_out).abs().max().item())
10 changes: 10 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -67,10 +67,15 @@
],
"hqq": ["prepare_for_hqq_linear"],
"hub_kernels": [
"FusedModuleBase",
"LayerRepository",
"fuse_modules",
"lazy_load_kernel",
"make_fused_module_class",
"register_fusion_patterns",
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
"unfuse_modules",
"use_kernel_forward_from_hub",
"use_kernel_func_from_hub",
"use_kernelized_func",
@@ -221,10 +226,15 @@
from .higgs import HiggsLinear, dequantize_higgs, quantize_with_higgs, replace_with_higgs_linear
from .hqq import prepare_for_hqq_linear
from .hub_kernels import (
FusedModuleBase,
LayerRepository,
fuse_modules,
lazy_load_kernel,
make_fused_module_class,
register_fusion_patterns,
register_kernel_mapping,
replace_kernel_forward_from_hub,
unfuse_modules,
use_kernel_forward_from_hub,
use_kernel_func_from_hub,
use_kernelized_func,
213 changes: 208 additions & 5 deletions src/transformers/integrations/hub_kernels.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import importlib.metadata
import os
import re
@@ -21,10 +22,14 @@
from packaging import version as pkg_version

from ..utils import ENV_VARS_TRUE_VALUES, logging
from ..utils.import_utils import is_kernels_available
from ..utils.import_utils import is_kernels_available, is_torch_available
from .flash_attention import flash_attention_forward


if is_torch_available():
import torch.nn as nn


logger = logging.get_logger(__name__)

try:
@@ -500,14 +505,212 @@ def allow_all_hub_kernels():
ALLOW_ALL_KERNELS = False


# Model class → {kernel_layer_name → [glob patterns]}
# Populated via `register_fusion_patterns` for models that cannot be modified directly.
_FUSION_PATTERNS_REGISTRY: dict[type, dict[str, list[str]]] = {}


def register_fusion_patterns(
model_class_or_instance,
patterns: dict[str, list[str]],
) -> None:
"""
Register kernel fusion patterns for a model class without modifying it directly.

This is an alternative to setting `_kernel_fusion_patterns` as a class attribute,
useful when the model class is frozen or comes from an external library.

Args:
model_class_or_instance:
The model class (or an instance of it) for which patterns are being registered.
patterns (`dict[str, list[str]]`):
Mapping from `kernel_layer_name` to a list of glob-style module paths,
identical in format to `_kernel_fusion_patterns`.
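
    Example (a minimal sketch; `FrozenQwenModel` is a hypothetical external class):

        register_fusion_patterns(
            FrozenQwenModel,
            {"RMSNormMLP": ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]},
        )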
"""
if not isinstance(model_class_or_instance, type):
model_class_or_instance = type(model_class_or_instance)
_FUSION_PATTERNS_REGISTRY[model_class_or_instance] = patterns


class FusedModuleBase(nn.Module):
Member Author

@ArthurZucker the API is way simpler than #44979 because the whole purpose of fusion here is to replace the forward with a kernel, so we do not need all the complex machinery. All we need is a module whose forward can be replaced with the kernel forward.
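
As a rough sketch of the mechanism (the names below are illustrative, not the final API):

    FusedClass = make_fused_module_class(("RMSNorm", "MLP"), "RMSNormMLP")
    fused = FusedClass([norm, mlp], ["post_attention_layernorm", "mlp"], fused_module_names=["RMSNorm", "MLP"])
    # kernelize(model) then swaps FusedClass.forward for the hub kernel's forward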

def __init__(
self,
modules_to_fuse: list["nn.Module"],
source_names: list[str],
fused_module_names: list[str] | None = None,
):
"""
Args:
modules_to_fuse: The source modules to fuse together.
source_names: The attribute names under which each module lives in its parent
(used to restore the originals when `unfuse_modules` is called).
fused_module_names: The names under which each source module is registered as a
child of this container (i.e. `self.<name>`). When `None`, the
`kernel_layer_name` attribute of each source module is used. Pass this
explicitly when the source modules do not carry `@use_kernel_forward_from_hub`.
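
Example (a minimal sketch; `norm` and `mlp` stand in for real source modules):

    container = FusedModuleBase(
        [norm, mlp],
        source_names=["post_attention_layernorm", "mlp"],
        fused_module_names=["RMSNorm", "MLP"],
    )
    assert container.RMSNorm is norm and container.MLP is mlp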
"""
super().__init__()
if len(modules_to_fuse) == 0:
raise ValueError("At least one module must be provided for fusion.")
if len(modules_to_fuse) != len(source_names):
raise ValueError("Length of modules_to_fuse and source_names must match.")

self._source_names = source_names

if fused_module_names is not None:
if len(fused_module_names) != len(modules_to_fuse):
raise ValueError("Length of fused_module_names and modules_to_fuse must match.")
for module, name in zip(modules_to_fuse, fused_module_names):
self.add_module(name, module)
self._fused_module_names = list(fused_module_names)
else:
for module in modules_to_fuse:
attr_name = getattr(module, "kernel_layer_name", None)
if attr_name is None:
raise ValueError(
f"Module {module} does not have a 'kernel_layer_name' attribute. "
f"Either decorate it with @use_kernel_forward_from_hub or provide "
f"explicit names via the inline pattern format: "
f'(("<name>", "<glob_path>"), ...).'
)
self.add_module(attr_name, module)
self._fused_module_names = [m.kernel_layer_name for m in modules_to_fuse]

# `kernelize` validates the kernel's forward signature against the class being replaced.
# Since the fused container sits at the position of the first module in the chain, the
# kernel's forward must match that module's signature. We patch the class-level forward
# here (via `functools.wraps`) so the signature is correct when `kernelize` inspects it.
# The body raises because this forward is always replaced by the kernel before any call.
@functools.wraps(type(modules_to_fuse[0]).forward)
def forward(self, *args, **kwargs):
raise NotImplementedError("FusedModule is a placeholder and should not be called directly.")

self.__class__.forward = forward

def __repr__(self):
names = ", ".join(self._fused_module_names)
return f"{self.__class__.__name__}(fused=({names}))"


@functools.cache
def make_fused_module_class(source_layer_names: tuple[str, ...], kernel_layer_name: str) -> type:
"""
Dynamically create and cache a `FusedModuleBase` subclass for a given fusion combination.

Args:
source_layer_names (`tuple[str, ...]`):
Ordered tuple of `kernel_layer_name` values of the modules being fused
(e.g. `("RMSNorm", "MLP")`). Used as the cache key — the same combination
always returns the same class object.
kernel_layer_name (`str`):
The name assigned to the fused class, used by `kernelize` to look up the
kernel in the mapping (e.g. `"RMSNormMLP"`).

Returns:
A subclass of `FusedModuleBase` with `kernel_layer_name` set as a class attribute.
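
    Example (cache behavior, which follows directly from `functools.cache`):

        FusedA = make_fused_module_class(("RMSNorm", "MLP"), "RMSNormMLP")
        FusedB = make_fused_module_class(("RMSNorm", "MLP"), "RMSNormMLP")
        assert FusedA is FusedB
        assert FusedA.kernel_layer_name == "RMSNormMLP"
        assert FusedA.__name__ == "Fused_RMSNorm_MLP"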
"""
return type(
f"Fused_{'_'.join(source_layer_names)}",
(FusedModuleBase,),
{"kernel_layer_name": kernel_layer_name},
)


def fuse_modules(
model: "nn.Module",
module_names_to_fuse: list[str],
kernel_layer_name: str,
source_layer_names: list[str] | None = None,
) -> None:
"""
Fuse a sequence of submodules into a single `FusedModuleBase` subclass in-place.

For every parent module whose immediate children match all entries in
`module_names_to_fuse`, the function:

- replaces the first module with a `FusedModuleBase` subclass instance that holds
all source modules as named children,
- replaces the remaining modules with `nn.Identity()` pass-throughs.

The fused container's `forward` signature is patched to match the first source
module's `forward`, satisfying the `kernelize` signature check.

Args:
model (`nn.Module`):
The model to modify in-place.
module_names_to_fuse (`list[str]`):
Glob-style paths of the modules to fuse, e.g.
`["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]`.
Integer indices are replaced with `*` so the same pattern applies to
every repeated block.
kernel_layer_name (`str`):
The `kernel_layer_name` assigned to the fused class, used by `kernelize`
to look up the kernel in the mapping (e.g. `"RMSNormMLP"`).
source_layer_names (`list[str]`, *optional*):
Explicit names for the child modules inside the fused container
(e.g. `["RMSNorm", "MLP"]`). When `None`, the `kernel_layer_name`
attribute of each source module is used.
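
    Example (a minimal sketch; the patterns mirror the example script above):

        fuse_modules(
            model,
            ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"],
            kernel_layer_name="RMSNormMLP",
        )
        # Each matching layer now holds the fused container at
        # `post_attention_layernorm` and an `nn.Identity()` at `mlp`.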
"""
pattern = re.compile(r"\d+")
for module_name, module in model.named_modules():
generic_children = {
re.sub(pattern, "*", f"{module_name}.{n}" if module_name else n): (n, child)
for n, child in module.named_children()
}
if not all(p in generic_children for p in module_names_to_fuse):
continue

child_names = [generic_children[p][0] for p in module_names_to_fuse]
modules_to_fuse = [generic_children[p][1] for p in module_names_to_fuse]

if source_layer_names is not None:
resolved_names = tuple(source_layer_names)
else:
resolved_names = tuple(getattr(m, "kernel_layer_name") for m in modules_to_fuse)
FusedClass = make_fused_module_class(resolved_names, kernel_layer_name)
fused_instance = FusedClass(modules_to_fuse, child_names, fused_module_names=list(resolved_names))

module.add_module(child_names[0], fused_instance)
for child_name in child_names[1:]:
module.add_module(child_name, nn.Identity())
Comment on lines +673 to +675
Collaborator

we're probably gonna have concerns with hooks, especially with TP potentially but also accelerate!



def unfuse_modules(model: "nn.Module") -> None:
"""
Revert a previous `fuse_modules` call in-place, restoring the original modules.

For each `FusedModuleBase` instance found in the model tree, the function:

- restores the original first module at the fused container's position,
- restores the remaining original modules at their original positions
(replacing the `nn.Identity()` pass-throughs).

Args:
model (`nn.Module`): The model to restore in-place.
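
    Example (a round-trip sketch, continuing the `fuse_modules` example):

        fuse_modules(model, patterns, "RMSNormMLP")
        unfuse_modules(model)  # the original modules are back at their positions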
"""
for parent in model.modules():
for name, child in list(parent.named_children()):
if not isinstance(child, FusedModuleBase):
continue
orig_modules = [getattr(child, layer_name) for layer_name in child._fused_module_names]
parent.add_module(name, orig_modules[0])
for sibling_name, orig_module in zip(child._source_names[1:], orig_modules[1:]):
parent.add_module(sibling_name, orig_module)


__all__ = [
    "FusedModuleBase",
    "LayerRepository",
    "fuse_modules",
    "get_kernel",
    "lazy_load_kernel",
    "make_fused_module_class",
    "register_fusion_patterns",
    "register_kernel_mapping",
    "register_kernel_mapping_transformers",
    "replace_kernel_forward_from_hub",
    "unfuse_modules",
    "use_kernel_forward_from_hub",
    "use_kernel_func_from_hub",
    "use_kernelized_func",
]  # type: ignore
4 changes: 4 additions & 0 deletions src/transformers/modeling_utils.py
@@ -3738,6 +3738,10 @@ def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None
register_kernel_mapping_transformers()

if kernel_config is not None and isinstance(kernel_config, KernelConfig):
# For n-to-1 entries (tuple keys), fuse the corresponding modules in the model
# before validation so that FusedModuleBase instances are visible to `sanitize_kernel_mapping`.
kernel_config.apply_fusions(self)

# This will make sure the mapping is valid, and the layers are registered in the model
kernel_config.sanitize_kernel_mapping(self)

49 changes: 49 additions & 0 deletions src/transformers/utils/kernel_config.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..integrations.hub_kernels import _FUSION_PATTERNS_REGISTRY, fuse_modules
from ..utils import PushToHubMixin


@@ -116,6 +117,54 @@ def update_kernel(self, repo_id, registered_name, layer_name, device, mode, revi
}
}

def apply_fusions(self, model):
"""
For each n-to-1 entry (tuple key) in the kernel mapping, find the fusion patterns
registered on the model, fuse the corresponding modules in-place, then replace the
tuple key with the resolved kernel layer name so the rest of the pipeline is unchanged.
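
Example (sketch of the key rewrite, using the repo string from the example script):

    before: {(("RMSNorm", "model.layers.*.post_attention_layernorm"),
              ("MLP", "model.layers.*.mlp")): "michaelbenayoun/dummy-rmsnorm-mlp:RMSNormMLP"}
    after:  {"RMSNormMLP": "michaelbenayoun/dummy-rmsnorm-mlp:RMSNormMLP"}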
"""
new_mapping = {}
for layer_name, kernel in self.kernel_mapping.items():
if not isinstance(layer_name, tuple):
new_mapping[layer_name] = kernel
continue

# Parse the target kernel layer name from the repo string (the part after ':')
repo_str = kernel if isinstance(kernel, str) else next(iter(kernel.values()))
if isinstance(repo_str, dict):
repo_str = next(iter(repo_str.values()))
kernel_layer_name = repo_str.split(":")[1] if ":" in repo_str else repo_str.split("/")[-1]

# Only fuse if a kernel is available for the current device
current_device = infer_device(model)
has_kernel_for_device = isinstance(kernel, str) or current_device in kernel
if not has_kernel_for_device:
continue

# Detect inline format: tuple of (kernel_layer_name, glob_pattern) pairs.
# e.g. (("RMSNorm", "model.layers.*.post_attention_layernorm"), ("MLP", "model.layers.*.mlp"))
is_inline = all(isinstance(item, tuple) and len(item) == 2 for item in layer_name)
if is_inline:
source_names = [item[0] for item in layer_name]
patterns = [item[1] for item in layer_name]
fuse_modules(model, patterns, kernel_layer_name, source_layer_names=source_names)
else:
fusion_patterns = getattr(model, "_kernel_fusion_patterns", None) or _FUSION_PATTERNS_REGISTRY.get(
type(model), {}
)
if kernel_layer_name not in fusion_patterns:
raise ValueError(
f"{type(model).__name__} does not define fusion patterns for '{kernel_layer_name}'. "
f'Either add `_kernel_fusion_patterns = {{"{kernel_layer_name}": [...]}}` to the model class, '
f"call `register_fusion_patterns({type(model).__name__}, ...)` before loading the model, "
f"or use the inline pattern format: "
f'(("{kernel_layer_name}", "<glob_path>"), ...).'
)
fuse_modules(model, fusion_patterns[kernel_layer_name], kernel_layer_name)
new_mapping[kernel_layer_name] = kernel

self.kernel_mapping = new_mapping

def store_registered_layer_names(self, model):
for name, module in model.named_modules():
if hasattr(module, "kernel_layer_name"):