Module Fusion API #44979
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44979&sha=79dd4b
ArthurZucker left a comment
This feels overly complex for what it's doing.
The whole registry/collector machinery exists just to capture inputs at runtime so FusedModule can replay the chain itself, but we already have APIs for this: OutputCollector.
Why not:
- monkey-patch the second module's class to absorb the first module's weights (dynamic weight converter handles the mapping)
- register the checkpoint conversion so weights load correctly
- decorate the patched class with the kernel
- auto-replace class A with a custom nn.Identity that just returns everything? (we have kwargs all over transformers)
- Best would be to replace calls to B with identity once A is fused
The fused class just owns all the weights and its forward runs the kernel.
import copy
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.conversion_mapping import WeightRenaming, register_checkpoint_conversion_mapping
from transformers.integrations import use_kernel_forward_from_hub
from transformers.monkey_patching import register_patch_mapping
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3MLP
# FusedNormMLP absorbs post_attention_layernorm weights into the MLP.
# The @use_kernel_forward_from_hub decorator is where a fused kernel plugs in.
@use_kernel_forward_from_hub("NormMLP")
class FusedNormMLP(Qwen3MLP):
def __init__(self, config):
super().__init__(config)
self.norm_weight = nn.Parameter(torch.ones(config.hidden_size))
self.norm_eps = config.rms_norm_eps
def forward(self, x):
# fused: norm then MLP — replaced end-to-end by the kernel
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.norm_eps) * self.norm_weight
return super().forward(x)
# Decoder layer that skips the now-redundant post_attention_layernorm call
# (FusedNormMLP owns the norm internally)
class FusedQwen3DecoderLayer(Qwen3DecoderLayer):
def forward(self, hidden_states, attention_mask=None, position_ids=None,
past_key_values=None, use_cache=False, position_embeddings=None, **kwargs):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask,
position_ids=position_ids, past_key_values=past_key_values,
use_cache=use_cache, position_embeddings=position_embeddings, **kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
# post_attention_layernorm is now inside FusedNormMLP — skip it here
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# load the reference model first, before the patch/conversion registrations below affect it
original = AutoModelForCausalLM.from_pretrained(model_id)
# Wire up: swap classes and redirect the norm weights into mlp.norm_weight on load
register_patch_mapping({
    "Qwen3MLP": FusedNormMLP,
    "Qwen3DecoderLayer": FusedQwen3DecoderLayer,
})
register_checkpoint_conversion_mapping(
    "qwen3",
    [WeightRenaming(
        source_patterns=r"post_attention_layernorm\.weight",
        target_patterns="mlp.norm_weight",
    )],
)
# patched model loads with norm weights remapped into mlp.norm_weight automatically
model = AutoModelForCausalLM.from_pretrained(model_id)
input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids
with torch.no_grad():
fused_out = model(input_ids).logits
original_out = original(input_ids).logits
print("Max diff fused vs original:", (fused_out - original_out).abs().max().item())
# expected: 0.0
I could not find the OutputCollector you mentioned.
As an answer to the rest, I do not really see how the suggested solution, which involves monkey patching, a checkpoint weight conversion, and a custom fused class definition, is less complex for what it is doing. Yes, it is possible to do it the way you mentioned, but it is much heavier than dynamically swapping only a few modules in the model: there is no weight conversion, no custom class definition for fusing, and no monkey-patching pattern. That being said, I hear your feedback. The primary motivation for this piece of code is to enable easy use of fused kernels, and as you can see the code is completely standalone and independent from transformers. The next step for me will be to see which approach fits best for the specific project I am working on.
We went with #45041 in the meantime, you can have a look, I think it should fit your use case!
Oh alright, in the meantime I worked on this: #45363.
What does this PR do?
Introduces src/transformers/module_fusion.py, a utility for fusing adjacent submodules in a model into a single FusedModule that executes them as a chain in one forward pass. The key components are:
- RegistryCollector: a transparent pass-through that captures inputs into a shared registry
- FusedModule: re-executes the full chain using registry-captured inputs, wiring outputs of one module as inputs to the next
- fuse_modules / unfuse_modules: in-place model surgery using glob-style paths (e.g. "layers.*.linear") to target repeated blocks
- ModuleSpec: describes the named inputs/outputs for each module in the chain; optional params can be omitted from the spec and filled by apply_defaults()
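For illustration, here is a minimal sketch of how these pieces could be combined, using the tiny Qwen3 checkpoint from elsewhere in this thread. The import path and the ModuleSpec / fuse_modules argument shapes are assumptions made for the example, not the PR's verified signatures:
from transformers import AutoModelForCausalLM
# ModuleSpec / fuse_modules / unfuse_modules are this PR's names; the exact arguments are assumed here
from transformers.module_fusion import ModuleSpec, fuse_modules, unfuse_modules

model = AutoModelForCausalLM.from_pretrained("michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random")

# one spec per module in the chain: the glob path targets every decoder layer, and the named
# inputs/outputs tell the resulting FusedModule how to wire one module's output into the next
specs = [
    ModuleSpec("model.layers.*.post_attention_layernorm",
               inputs=["hidden_states"], outputs=["hidden_states"]),
    ModuleSpec("model.layers.*.mlp",
               inputs=["hidden_states"], outputs=["hidden_states"]),
]

fuse_modules(model, specs)   # in-place surgery: each (norm, mlp) pair becomes one FusedModule
unfuse_modules(model)        # and the surgery is reversible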
Motivation and example
The initial motivation is to provide a way to fuse multiple layers into one and then have that fused layer use a custom kernel from the kernels library. This way we can use fused kernels in Transformers.
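As a rough sketch of that end goal (not this PR's API): a chain collapsed into a single module can be decorated with use_kernel_forward_from_hub, the same decorator used in the review comment above, so the kernels library can later swap its forward for a fused kernel from the Hub. The NormLinear module and the kernel name below are placeholders:
import torch
import torch.nn as nn
from transformers.integrations import use_kernel_forward_from_hub

# placeholder fused block: a norm and a projection collapsed into one module, so a single
# hub kernel could later replace its forward when the model is kernelized
@use_kernel_forward_from_hub("NormLinear")  # "NormLinear" is a made-up kernel name
class NormLinear(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # eager fallback; the kernels library may swap this out for a fused kernel
        return self.proj(self.norm(x))

print(NormLinear()(torch.randn(2, 16)).shape)  # torch.Size([2, 16])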