From e75c44864693ff2f767911e67079a0b0ef76efe5 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 23 Mar 2026 10:10:54 +0100 Subject: [PATCH 1/4] Add `_and_mul` capability to `activations.ClassInstantier` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/activations.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index a51ebca341d4..ccab30e0b4fb 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -216,9 +216,35 @@ def forward(self, input): class ClassInstantier(OrderedDict): def __getitem__(self, key): + key, _, operation = key.partition("_and_") content = super().__getitem__(key) cls, kwargs = content if isinstance(content, tuple) else (content, {}) - return cls(**kwargs) + act = cls(**kwargs) + if operation == "mul": + class ActAndMul(nn.Module): + """ + The module computes x -> act(x[:d]) * x[d:] where d = x.shape[-1] // 2. + + Shapes: + x: (..., 2 * d) + return: (..., d) + """ + def __init__(self): + super().__init__() + self.act = act + + def forward(self, input: Tensor) -> Tensor: + d = input.shape[-1] // 2 + return self.act(input[..., :d]) * input[..., d:] + + ActAndMul.__name__ = f"{type(act).__name__}AndMul" + ActAndMul.__qualname__ = f"{type(act).__qualname__}AndMul" + act = ActAndMul() + elif operation: + raise ValueError( + f"Invalid {operation=} to fuse with activation {cls.__name__}. Only 'mul' is supported." 
+ ) + return act class XIELUActivation(nn.Module): From 494d00a97caa0211b87f8c857ed68d68bb316062 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:47:59 +0100 Subject: [PATCH 2/4] Add `fuse_layers` to `PreTrainedModel.from_pretrained` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/fusion_mapping.py | 183 +++++++++++++++++++++++++++++ src/transformers/modeling_utils.py | 10 ++ 2 files changed, 193 insertions(+) create mode 100644 src/transformers/fusion_mapping.py diff --git a/src/transformers/fusion_mapping.py b/src/transformers/fusion_mapping.py new file mode 100644 index 000000000000..c2a6466a7d41 --- /dev/null +++ b/src/transformers/fusion_mapping.py @@ -0,0 +1,183 @@ +# Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +import torch +from torch import nn + +from .activations import ACT2FN +from .core_model_loading import Concatenate, WeightConverter +from .monkey_patching import register_patch_mapping + + +if TYPE_CHECKING: + from .modeling_utils import PreTrainedModel + + +def _make_fused_mlp(original_cls): + """ + Create a fused MLP class from an original that has separate ``gate_proj`` and + ``up_proj``. The returned subclass replaces those with a single ``gate_up_proj`` + and overrides ``_compute_gate_up`` to use it. 
+ """ + + class FusedMLP(original_cls): + _weight_converter = WeightConverter( + source_patterns=[".gate_proj.weight$", ".up_proj.weight$"], + target_patterns=".gate_up_proj.weight$", + operations=[Concatenate(dim=0)], + ) + + def __init__(self, config): + super().__init__(config) + + del self.gate_proj + del self.up_proj + self.gate_up_proj = nn.Linear( + self.hidden_size, self.intermediate_size * 2, bias=config.mlp_bias + ) + self.act_fn = ACT2FN[config.hidden_act + "_and_mul"] + + def _compute_gate_up(self, x): + return self.act_fn(self.gate_up_proj(x)) + + FusedMLP.__name__ = f"Fused{original_cls.__name__}" + FusedMLP.__qualname__ = f"Fused{original_cls.__qualname__}" + return FusedMLP + + +def _make_fused_attention(original_cls): + """ + Create a fused attention class from an original that has separate ``q_proj``, + ``k_proj``, ``v_proj``. The returned subclass replaces those with a single + ``qkv_proj`` and overrides ``_project_qkv`` to split the fused output. + """ + + class FusedAttention(original_cls): + _weight_converter = WeightConverter( + source_patterns=[".q_proj.weight$", ".k_proj.weight$", ".v_proj.weight$"], + target_patterns=".qkv_proj.weight$", + operations=[Concatenate(dim=0)], + ) + + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + + del self.q_proj + del self.k_proj + del self.v_proj + + self.q_size = config.num_attention_heads * self.head_dim + self.kv_size = config.num_key_value_heads * self.head_dim + self.qkv_proj = nn.Linear( + config.hidden_size, + self.q_size + 2 * self.kv_size, + bias=config.attention_bias, + ) + + def _project_qkv(self, hidden_states, hidden_shape): + qkv = self.qkv_proj(hidden_states) + q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) + return ( + q.view(hidden_shape).transpose(1, 2), + k.view(hidden_shape).transpose(1, 2), + v.view(hidden_shape).transpose(1, 2), + ) + + FusedAttention.__name__ = f"Fused{original_cls.__name__}" + 
FusedAttention.__qualname__ = f"Fused{original_cls.__qualname__}" + return FusedAttention + + +_fusion_cache: dict[type, dict[str, type[nn.Module]]] = {} + + +def _discover_fusable_classes(cls: "type[PreTrainedModel]", config) -> dict[str, type[nn.Module]]: + """ + Instantiate *cls* on the meta device and walk ``model.modules()`` to find + ``nn.Module`` subclasses that expose ``_compute_gate_up`` (MLP fusion) or + ``_project_qkv`` (attention fusion). Returns a mapping from original + class name → fused replacement class. + """ + if cls in _fusion_cache: + return _fusion_cache[cls] + + with torch.device("meta"): + model = cls(config) + + seen: set[type] = set() + patch_mapping: dict[str, type[nn.Module]] = {} + for submodule in model.modules(): + subcls = type(submodule) + if subcls in seen: + continue + seen.add(subcls) + if hasattr(subcls, "_compute_gate_up"): + patch_mapping[subcls.__name__] = _make_fused_mlp(subcls) + elif hasattr(subcls, "_project_qkv"): + patch_mapping[subcls.__name__] = _make_fused_attention(subcls) + + _fusion_cache[cls] = patch_mapping + return patch_mapping + + +def _update_tp_plan(tp_plan: dict[str, str]) -> None: + """Rewrite *tp_plan* in-place to reflect fused projections.""" + for q_key in [k for k in tp_plan if k.endswith(".q_proj")]: + prefix = q_key.rsplit(".q_proj", 1)[0] + k_key, v_key = f"{prefix}.k_proj", f"{prefix}.v_proj" + if k_key in tp_plan and v_key in tp_plan: + del tp_plan[q_key], tp_plan[k_key], tp_plan[v_key] + tp_plan[f"{prefix}.qkv_proj"] = "colwise_qkv" + + for gate_key in [k for k in tp_plan if k.endswith(".gate_proj")]: + prefix = gate_key.rsplit(".gate_proj", 1)[0] + up_key = f"{prefix}.up_proj" + if up_key in tp_plan: + del tp_plan[gate_key], tp_plan[up_key] + tp_plan[f"{prefix}.gate_up_proj"] = "colwise_merged" + + +def register_fusion_patches(cls: "type[PreTrainedModel]", config) -> None: + """ + Register all fusion-related changes for *cls*: + + 1. 
Monkey-patches into the global patch mapping (for ``apply_patches()``) + 2. Weight converters into the checkpoint conversion mapping + (for ``get_model_conversion_mapping()``) + 3. TP plan updates on ``config_class.base_model_tp_plan`` + """ + from .conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping + + fusable_classes = _discover_fusable_classes(cls, config) + if not fusable_classes: + return + + # 1. monkey-patches + register_patch_mapping(fusable_classes) + + # 2. weight converters + config_class = getattr(cls, "config_class", None) + model_type = getattr(config_class, "model_type", None) if config_class is not None else None + if model_type is not None: + converters = [fused_cls._weight_converter for fused_cls in fusable_classes.values()] + existing = get_checkpoint_conversion_mapping(model_type) + if existing is not None: + converters = existing + converters + register_checkpoint_conversion_mapping(model_type, converters, overwrite=True) + + # 3. tp plan + if config_class is not None and config_class.base_model_tp_plan is not None: + _update_tp_plan(config_class.base_model_tp_plan) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1cdb033cb709..75b488bc7310 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3689,6 +3689,7 @@ def from_pretrained( revision: str = "main", use_safetensors: bool | None = None, weights_only: bool = True, + fuse_layers: bool = False, **kwargs, ) -> SpecificPreTrainedModelType: r""" @@ -3867,6 +3868,8 @@ def from_pretrained( Indicates whether unpickler should be restricted to loading only tensors, primitive types, dictionaries and any types added via torch.serialization.add_safe_globals(). When set to False, we can load wrapper tensor subclass weights. + fuse_layers (`bool`, *optional*, defaults to `False`): + Whether or not to fuse some layers of the model when loading it. 
This should only be used as an inference optimization. key_mapping (`dict[str, str], *optional*): A potential mapping of the weight names if using a model on the Hub which is compatible to a Transformers architecture, but was not converted accordingly. @@ -4080,6 +4083,13 @@ def from_pretrained( ) config.name_or_path = pretrained_model_name_or_path + + # Register any fusion patch mappings necessary to fuse layers at initialization for inference performance + if fuse_layers: + from .fusion_mapping import register_fusion_patches + + register_fusion_patches(cls, config) + model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called, allow_all_kernels) config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. From bbf93c4ba7946d0702b06340615fec344567be91 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:48:09 +0100 Subject: [PATCH 3/4] Add fusion capability to Llama Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/models/llama/modeling_llama.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 9d659c7c6f08..7486ae4a3245 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -179,9 +179,11 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] + def _compute_gate_up(self, x): + return self.act_fn(self.gate_proj(x)) * self.up_proj(x) + def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj + return self.down_proj(self._compute_gate_up(x)) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -248,6 +250,12 @@ def 
__init__(self, config: LlamaConfig, layer_idx: int): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + def _project_qkv(self, hidden_states, hidden_shape): + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + return query_states, key_states, value_states + def forward( self, hidden_states: torch.Tensor, @@ -259,9 +267,7 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states, key_states, value_states = self._project_qkv(hidden_states, hidden_shape) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) From 20e25719b972729efe1722c862d209903e7587ef Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 23 Mar 2026 12:42:38 +0100 Subject: [PATCH 4/4] Move activation fiddling to `fusion_mapping.py` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- src/transformers/activations.py | 28 +---------------- src/transformers/fusion_mapping.py | 48 ++++++++++++++++++++---------- 2 files changed, 33 insertions(+), 43 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index ccab30e0b4fb..a51ebca341d4 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -216,35 +216,9 @@ def forward(self, input): class ClassInstantier(OrderedDict): def __getitem__(self, key): - key, _, operation = key.partition("_and_") content = super().__getitem__(key) cls, kwargs = content if 
isinstance(content, tuple) else (content, {}) - act = cls(**kwargs) - if operation == "mul": - class ActAndMul(nn.Module): - """ - The module computes x -> act(x[:d]) * x[d:] where d = x.shape[-1] // 2. - - Shapes: - x: (..., 2 * d) - return: (..., d) - """ - def __init__(self): - super().__init__() - self.act = act - - def forward(self, input: Tensor) -> Tensor: - d = input.shape[-1] // 2 - return self.act(input[..., :d]) * input[..., d:] - - ActAndMul.__name__ = f"{type(act).__name__}AndMul" - ActAndMul.__qualname__ = f"{type(act).__qualname__}AndMul" - act = ActAndMul() - elif operation: - raise ValueError( - f"Invalid {operation=} to fuse with activation {cls.__name__}. Only 'mul' is supported." - ) - return act + return cls(**kwargs) class XIELUActivation(nn.Module): diff --git a/src/transformers/fusion_mapping.py b/src/transformers/fusion_mapping.py index c2a6466a7d41..4ecf92613102 100644 --- a/src/transformers/fusion_mapping.py +++ b/src/transformers/fusion_mapping.py @@ -33,6 +33,24 @@ def _make_fused_mlp(original_cls): and overrides ``_compute_gate_up`` to use it. """ + class ActAndMul(nn.Module): + """ + The module computes x -> act(x[:d]) * x[d:] where d = x.shape[-1] // 2. 
+ + Shapes: + x: (..., 2 * d) + return: (..., d) + """ + def __init__(self, act): + super().__init__() + self.act = act + ActAndMul.__name__ = f"{type(act).__name__}AndMul" + ActAndMul.__qualname__ = f"{type(act).__qualname__}AndMul" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + d = input.shape[-1] // 2 + return self.act(input[..., :d]) * input[..., d:] + class FusedMLP(original_cls): _weight_converter = WeightConverter( source_patterns=[".gate_proj.weight$", ".up_proj.weight$"], @@ -43,12 +61,13 @@ class FusedMLP(original_cls): def __init__(self, config): super().__init__(config) - del self.gate_proj - del self.up_proj - self.gate_up_proj = nn.Linear( - self.hidden_size, self.intermediate_size * 2, bias=config.mlp_bias - ) - self.act_fn = ACT2FN[config.hidden_act + "_and_mul"] + in_features = self.gate_proj.in_features + out_features = self.gate_proj.out_features + self.up_proj.out_features + bias = self.gate_proj.bias is not None + del self.gate_proj, self.up_proj + + self.gate_up_proj = nn.Linear(in_features=in_features, out_features=out_features, bias=bias) + self.act_fn = ActAndMul(self.act_fn) def _compute_gate_up(self, x): return self.act_fn(self.gate_up_proj(x)) @@ -75,17 +94,14 @@ class FusedAttention(original_cls): def __init__(self, config, layer_idx: int): super().__init__(config, layer_idx) - del self.q_proj - del self.k_proj - del self.v_proj + in_features = self.q_proj.in_features + self.q_size = self.q_proj.out_features + self.kv_size = self.k_proj.out_features + out_features = self.q_size + 2 * self.kv_size + bias = self.q_proj.bias is not None + del self.q_proj, self.k_proj, self.v_proj - self.q_size = config.num_attention_heads * self.head_dim - self.kv_size = config.num_key_value_heads * self.head_dim - self.qkv_proj = nn.Linear( - config.hidden_size, - self.q_size + 2 * self.kv_size, - bias=config.attention_bias, - ) + self.qkv_proj = nn.Linear(in_features=in_features, out_features=out_features, bias=bias) def _project_qkv(self, 
hidden_states, hidden_shape): qkv = self.qkv_proj(hidden_states)