
Module Fusion API #44979

Open

michaelbenayoun wants to merge 5 commits into huggingface:main from michaelbenayoun:fuse_modules

Conversation

@michaelbenayoun
Member

@michaelbenayoun michaelbenayoun commented Mar 24, 2026

What does this PR do?

Introduces src/transformers/module_fusion.py, a utility for fusing adjacent submodules in a model into a single FusedModule that executes them as a chain in one forward pass. The key components are:

  • RegistryCollector: a transparent pass-through that captures inputs into a shared registry.
  • FusedModule: re-executes the full chain using registry-captured inputs, wiring the outputs of one module into the inputs of the next (see the sketch below).
  • fuse_modules / unfuse_modules: in-place model surgery using glob-style paths (e.g. "layers.*.linear") to target repeated blocks.
  • ModuleSpec: describes the named inputs/outputs for each module in the chain; optional params can be omitted from the spec and filled by apply_defaults().
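
For intuition, the chained execution can be sketched with plain PyTorch, independently of the PR's actual classes. FusedChain below is illustrative only; the real FusedModule additionally replays registry-captured inputs and uses each ModuleSpec to name them:

import torch
from torch import nn

class FusedChain(nn.Module):
    # Toy stand-in for the idea behind FusedModule: run a list of submodules
    # as one forward pass, feeding each module's output into the next.
    def __init__(self, modules):
        super().__init__()
        self.chain = nn.ModuleList(modules)

    def forward(self, hidden_states):
        for module in self.chain:
            hidden_states = module(hidden_states)
        return hidden_states

norm = nn.LayerNorm(16)
mlp = nn.Sequential(nn.Linear(16, 64), nn.SiLU(), nn.Linear(64, 16))
fused = FusedChain([norm, mlp])
out = fused(torch.randn(2, 4, 16))  # one call executes norm + MLP in sequence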

Motivation and example

The initial motivation is to provide a way to fuse multiple layers into a single module that can then be backed by a custom kernel from the kernels library. This way we can use fused kernels in Transformers.

import copy

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.module_fusion import ModuleSpec, fuse_modules, unfuse_modules


model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
original = copy.deepcopy(model)

fuse_modules(
    model,
    [
        "model.layers.*.post_attention_layernorm", 
        "model.layers.*.mlp"
    ],
    [
        ModuleSpec(inputs=["hidden_states"], outputs=["hidden_states"]),  # Qwen3RMSNorm (post_attention_layernorm)
        ModuleSpec(inputs=["hidden_states"], outputs=["hidden_states"]),  # Qwen3MLP
    ]
)

input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids

with torch.no_grad():
    fused_out = model(input_ids).logits
    original_out = original(input_ids).logits

print("Max diff fused vs original:", (fused_out - original_out).abs().max().item())

unfuse_modules(model)

with torch.no_grad():
    unfused_out = model(input_ids).logits

print("Max diff unfused vs original:", (unfused_out - original_out).abs().max().item())

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44979&sha=79dd4b

Collaborator

@ArthurZucker ArthurZucker left a comment


This feels overly complex for what it's doing.

The whole registry/collector machinery exists just to capture inputs at runtime so FusedModule can replay the chain itself, but we already have APIs for this: OutputCollector.

Why not:

  • monkey-patch the second module's class to absorb the first module's weights (dynamic weight converter handles the mapping)
  • register the checkpoint conversion so weights load correctly
  • decorate the patched class with the kernel
  • auto replace class A with a custom nn.Identity that just returns everything? (we have kwargs all over transformers).
  • Best would be to replace calls to B with identity once A is fused

The fused class just owns all the weights and its forward runs the kernel.

@ArthurZucker
Collaborator

  import copy

  import torch
  import torch.nn as nn

  from transformers import AutoModelForCausalLM, AutoTokenizer
  from transformers.conversion_mapping import WeightRenaming, register_checkpoint_conversion_mapping
  from transformers.integrations import use_kernel_forward_from_hub
  from transformers.monkey_patching import register_patch_mapping
  from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3MLP


  # FusedNormMLP absorbs post_attention_layernorm weights into the MLP.
  # The @use_kernel_forward_from_hub decorator is where a fused kernel plugs in.
  @use_kernel_forward_from_hub("NormMLP")
  class FusedNormMLP(Qwen3MLP):
      def __init__(self, config):
          super().__init__(config)
          self.norm_weight = nn.Parameter(torch.ones(config.hidden_size))
          self.norm_eps = config.rms_norm_eps

      def forward(self, x):
          # fused: norm then MLP — replaced end-to-end by the kernel
          variance = x.pow(2).mean(-1, keepdim=True)
          x = x * torch.rsqrt(variance + self.norm_eps) * self.norm_weight
          return super().forward(x)


  # Decoder layer that skips the now-redundant post_attention_layernorm call
  # (FusedNormMLP owns the norm internally)
  class FusedQwen3DecoderLayer(Qwen3DecoderLayer):
      def forward(self, hidden_states, attention_mask=None, position_ids=None,
                  past_key_values=None, use_cache=False, position_embeddings=None, **kwargs):
          residual = hidden_states
          hidden_states = self.input_layernorm(hidden_states)
          hidden_states, _ = self.self_attn(
              hidden_states=hidden_states, attention_mask=attention_mask,
              position_ids=position_ids, past_key_values=past_key_values,
              use_cache=use_cache, position_embeddings=position_embeddings, **kwargs,
          )
          hidden_states = residual + hidden_states

          residual = hidden_states
          # post_attention_layernorm is now inside FusedNormMLP — skip it here
          hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states
          return hidden_states


  # Wire up: swap classes and redirect the norm weights into mlp.norm_weight on load
  register_patch_mapping({
      "Qwen3MLP": FusedNormMLP,
      "Qwen3DecoderLayer": FusedQwen3DecoderLayer,
  })

  register_checkpoint_conversion_mapping(
      "qwen3",
      [WeightRenaming(
          source_patterns=r"post_attention_layernorm\.weight",
          target_patterns="mlp.norm_weight",
      )],
  )

  model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random"
  tokenizer = AutoTokenizer.from_pretrained(model_id)

  # patched model loads with norm weights remapped into mlp.norm_weight automatically
  model = AutoModelForCausalLM.from_pretrained(model_id)
  original = AutoModelForCausalLM.from_pretrained(model_id)  # load again before patches affect it

  input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids

  with torch.no_grad():
      fused_out = model(input_ids).logits
      original_out = original(input_ids).logits

  print("Max diff fused vs original:", (fused_out - original_out).abs().max().item())
  # expected: 0.0

@michaelbenayoun
Member Author

This feels overly complex for what it's doing.

The whole registry/collector machinery exists just to capture inputs at runtime so FusedModule can replay the chain itself, but we already have APIs for this: OutputCollector.

Why not:

  • monkey-patch the second module's class to absorb the first module's weights (dynamic weight converter handles the mapping)
  • register the checkpoint conversion so weights load correctly
  • decorate the patched class with the kernel
  • auto replace class A with a custom nn.Identity that just returns everything? (we have kwargs all over transformers).
  • Best would be to replace calls to B with identity once A is fused

The fused class just owns all the weights and its forward runs the kernel.

I could not find the OutputCollector mentioned, do you have a link please? Also, we want to capture the inputs, not the outputs; can it still be used?
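
For reference, inputs (rather than outputs) can be captured with a standard PyTorch forward pre-hook; a minimal sketch, independent of any transformers collector API:

import torch
from torch import nn

captured = {}

def capture_inputs(module, args, kwargs):
    # store the positional and keyword inputs without altering the call
    captured["args"] = args
    captured["kwargs"] = kwargs

layer = nn.Linear(8, 8)
handle = layer.register_forward_pre_hook(capture_inputs, with_kwargs=True)
layer(torch.randn(2, 8))
handle.remove()
print(captured["args"][0].shape)  # torch.Size([2, 8])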

As an answer to the rest, I do not really see how the suggested solution, which involves:

  • Defining a class every time we want to fuse a module
  • Monkey-patching classes
  • Handling weight transformations

is less overly complex for what it is doing. Yes, it is possible to do it the way you mentioned, but it is much heavier than simply changing a few modules in the model dynamically. With this approach there is no weight conversion, no custom class definition for fusing, and no monkey-patching pattern.

That being said, I take your feedback. The primary motivation for this piece of code is to enable easy use of fused kernels. As you can see, the code is completely standalone and independent from the rest of transformers. The next step for me will be to figure out the best approach for the specific project I am working on.

@ArthurZucker
Collaborator

We went with #45041 in the meantime; you can have a look, I think it should fit your use case!

@michaelbenayoun
Member Author

Oh alright, in the meantime I worked on #45363.
Happy to get your opinion on it.
