199 changes: 199 additions & 0 deletions src/transformers/fusion_mapping.py
@@ -0,0 +1,199 @@
# Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

import torch
from torch import nn

from .core_model_loading import Concatenate, WeightConverter
from .monkey_patching import register_patch_mapping


if TYPE_CHECKING:
from .modeling_utils import PreTrainedModel


def _make_fused_mlp(original_cls):
"""
Create a fused MLP class from an original that has separate ``gate_proj`` and
``up_proj``. The returned subclass replaces those with a single ``gate_up_proj``
and overrides ``_compute_gate_up`` to use it.
"""

class ActAndMul(nn.Module):
"""
The module computes x -> act(x[..., :d]) * x[..., d:] where d = x.shape[-1] // 2.

Shapes:
x: (..., 2 * d)
return: (..., d)
"""
def __init__(self, act):
super().__init__()
self.act = act
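# Rename the wrapper class so printed module trees read e.g. ``SiLUAndMul`` rather than the generic ``ActAndMul``.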
ActAndMul.__name__ = f"{type(act).__name__}AndMul"
ActAndMul.__qualname__ = f"{type(act).__qualname__}AndMul"

def forward(self, input: torch.Tensor) -> torch.Tensor:
d = input.shape[-1] // 2
return self.act(input[..., :d]) * input[..., d:]

class FusedMLP(original_cls):
_weight_converter = WeightConverter(
source_patterns=[".gate_proj.weight$", ".up_proj.weight$"],
target_patterns=".gate_up_proj.weight$",
operations=[Concatenate(dim=0)],
)

def __init__(self, config):
super().__init__(config)

in_features = self.gate_proj.in_features
out_features = self.gate_proj.out_features + self.up_proj.out_features
bias = self.gate_proj.bias is not None
del self.gate_proj, self.up_proj

self.gate_up_proj = nn.Linear(in_features=in_features, out_features=out_features, bias=bias)
self.act_fn = ActAndMul(self.act_fn)

def _compute_gate_up(self, x):
return self.act_fn(self.gate_up_proj(x))

FusedMLP.__name__ = f"Fused{original_cls.__name__}"
FusedMLP.__qualname__ = f"Fused{original_cls.__qualname__}"
return FusedMLP


def _make_fused_attention(original_cls):
"""
Create a fused attention class from an original that has separate ``q_proj``,
``k_proj``, ``v_proj``. The returned subclass replaces those with a single
``qkv_proj`` and overrides ``_project_qkv`` to split the fused output.
"""

class FusedAttention(original_cls):
_weight_converter = WeightConverter(
source_patterns=[".q_proj.weight$", ".k_proj.weight$", ".v_proj.weight$"],
target_patterns=".qkv_proj.weight$",
operations=[Concatenate(dim=0)],
)

def __init__(self, config, layer_idx: int):
super().__init__(config, layer_idx)

in_features = self.q_proj.in_features
self.q_size = self.q_proj.out_features
self.kv_size = self.k_proj.out_features
out_features = self.q_size + 2 * self.kv_size
bias = self.q_proj.bias is not None
del self.q_proj, self.k_proj, self.v_proj

self.qkv_proj = nn.Linear(in_features=in_features, out_features=out_features, bias=bias)

def _project_qkv(self, hidden_states, hidden_shape):
qkv = self.qkv_proj(hidden_states)
q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1)
return (
q.view(hidden_shape).transpose(1, 2),
k.view(hidden_shape).transpose(1, 2),
v.view(hidden_shape).transpose(1, 2),
)

FusedAttention.__name__ = f"Fused{original_cls.__name__}"
FusedAttention.__qualname__ = f"Fused{original_cls.__qualname__}"
return FusedAttention


_fusion_cache: dict[type, dict[str, type[nn.Module]]] = {}


def _discover_fusable_classes(cls: "type[PreTrainedModel]", config) -> dict[str, type[nn.Module]]:
"""
Instantiate *cls* on the meta device and walk ``model.modules()`` to find
``nn.Module`` subclasses that expose ``_compute_gate_up`` (MLP fusion) or
``_project_qkv`` (attention fusion). Returns a mapping from original
class name → fused replacement class.
"""
if cls in _fusion_cache:
return _fusion_cache[cls]

with torch.device("meta"):
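# Nothing is materialized on the meta device; we only need the module tree to inspect submodule classes.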
model = cls(config)

seen: set[type] = set()
patch_mapping: dict[str, type[nn.Module]] = {}
for submodule in model.modules():
subcls = type(submodule)
if subcls in seen:
continue
seen.add(subcls)
if hasattr(subcls, "_compute_gate_up"):
patch_mapping[subcls.__name__] = _make_fused_mlp(subcls)
elif hasattr(subcls, "_project_qkv"):
patch_mapping[subcls.__name__] = _make_fused_attention(subcls)

_fusion_cache[cls] = patch_mapping
return patch_mapping


def _update_tp_plan(tp_plan: dict[str, str]) -> None:
"""Rewrite *tp_plan* in-place to reflect fused projections."""
for q_key in [k for k in tp_plan if k.endswith(".q_proj")]:
prefix = q_key.rsplit(".q_proj", 1)[0]
k_key, v_key = f"{prefix}.k_proj", f"{prefix}.v_proj"
if k_key in tp_plan and v_key in tp_plan:
del tp_plan[q_key], tp_plan[k_key], tp_plan[v_key]
tp_plan[f"{prefix}.qkv_proj"] = "colwise_qkv"

for gate_key in [k for k in tp_plan if k.endswith(".gate_proj")]:
prefix = gate_key.rsplit(".gate_proj", 1)[0]
up_key = f"{prefix}.up_proj"
if up_key in tp_plan:
del tp_plan[gate_key], tp_plan[up_key]
tp_plan[f"{prefix}.gate_up_proj"] = "colwise_merged"


def register_fusion_patches(cls: "type[PreTrainedModel]", config) -> None:
"""
Register all fusion-related changes for *cls*:

1. Monkey-patches into the global patch mapping (for ``apply_patches()``)
2. Weight converters into the checkpoint conversion mapping
(for ``get_model_conversion_mapping()``)
3. TP plan updates on ``config_class.base_model_tp_plan``
"""
from .conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping

fusable_classes = _discover_fusable_classes(cls, config)
if not fusable_classes:
return

# 1. monkey-patches
register_patch_mapping(fusable_classes)

# 2. weight converters
config_class = getattr(cls, "config_class", None)
model_type = getattr(config_class, "model_type", None) if config_class is not None else None
if model_type is not None:
converters = [fused_cls._weight_converter for fused_cls in fusable_classes.values()]
existing = get_checkpoint_conversion_mapping(model_type)
if existing is not None:
converters = existing + converters
register_checkpoint_conversion_mapping(model_type, converters, overwrite=True)

# 3. tp plan
if config_class is not None and config_class.base_model_tp_plan is not None:
_update_tp_plan(config_class.base_model_tp_plan)
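
For reference, the ``WeightConverter`` entries above express a plain concatenation of the original checkpoint tensors along the output dimension (dim 0 of an ``nn.Linear`` weight). Below is a minimal self-contained sketch of the same fusion done by hand; the sizes and variable names are illustrative, not taken from a real config:

import torch
from torch import nn

# Illustrative sizes only.
hidden_size, q_out, kv_out = 64, 64, 16
q_proj = nn.Linear(hidden_size, q_out, bias=False)
k_proj = nn.Linear(hidden_size, kv_out, bias=False)
v_proj = nn.Linear(hidden_size, kv_out, bias=False)

# What ``Concatenate(dim=0)`` does to the checkpoint tensors: stack the q/k/v rows of the weight.
qkv_proj = nn.Linear(hidden_size, q_out + 2 * kv_out, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

# The fused forward then recovers the three streams with one matmul followed by a split.
x = torch.randn(2, 8, hidden_size)
q, k, v = torch.split(qkv_proj(x), [q_out, kv_out, kv_out], dim=-1)
assert torch.allclose(q, q_proj(x), atol=1e-5)
assert torch.allclose(k, k_proj(x), atol=1e-5)
assert torch.allclose(v, v_proj(x), atol=1e-5)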
10 changes: 10 additions & 0 deletions src/transformers/modeling_utils.py
@@ -3689,6 +3689,7 @@ def from_pretrained(
revision: str = "main",
use_safetensors: bool | None = None,
weights_only: bool = True,
fuse_layers: bool = False,
**kwargs,
) -> SpecificPreTrainedModelType:
r"""
@@ -3867,6 +3868,8 @@ def from_pretrained(
Indicates whether unpickler should be restricted to loading only tensors, primitive types,
dictionaries and any types added via torch.serialization.add_safe_globals().
When set to False, we can load wrapper tensor subclass weights.
fuse_layers (`bool`, *optional*, defaults to `False`):
Whether or not to fuse compatible layers when loading the model (e.g. the attention `q_proj`/`k_proj`/`v_proj` into a single `qkv_proj`, and the MLP `gate_proj`/`up_proj` into a single `gate_up_proj`). This should only be used as an inference-time optimization.
key_mapping (`dict[str, str]`, *optional*):
A potential mapping of the weight names if using a model on the Hub which is compatible to a Transformers
architecture, but was not converted accordingly.
@@ -4080,6 +4083,13 @@ def from_pretrained(
)

config.name_or_path = pretrained_model_name_or_path

# Register any fusion patch mappings needed to fuse layers at initialization, as an inference-time optimization
if fuse_layers:
from .fusion_mapping import register_fusion_patches

register_fusion_patches(cls, config)

model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called, allow_all_kernels)

config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
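
A minimal usage sketch of the new flag (the checkpoint name is only an example, and this assumes the fusion hooks introduced in this PR are available for the architecture):

from transformers import AutoModelForCausalLM

# fuse_layers=True registers the fusion patches before the model is built, so the loaded
# model ends up with a single qkv_proj per attention block and a single gate_up_proj per MLP.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", fuse_layers=True)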
16 changes: 11 additions & 5 deletions src/transformers/models/llama/modeling_llama.py
@@ -179,9 +179,11 @@ def __init__(self, config):
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]

def _compute_gate_up(self, x):
return self.act_fn(self.gate_proj(x)) * self.up_proj(x)

def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
return self.down_proj(self._compute_gate_up(x))


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -248,6 +250,12 @@ def __init__(self, config: LlamaConfig, layer_idx: int):
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)

def _project_qkv(self, hidden_states, hidden_shape):
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
return query_states, key_states, value_states

def forward(
self,
hidden_states: torch.Tensor,
@@ -259,9 +267,7 @@ def forward(
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states, key_states, value_states = self._project_qkv(hidden_states, hidden_shape)

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
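
To make the TP-plan rewrite concrete, this is roughly what ``_update_tp_plan`` does to a Llama-style ``base_model_tp_plan``; the input below is a representative subset of the real plan, shown only for illustration:

tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}
_update_tp_plan(tp_plan)
# tp_plan is now:
# {
#     "layers.*.self_attn.o_proj": "rowwise",
#     "layers.*.mlp.down_proj": "rowwise",
#     "layers.*.self_attn.qkv_proj": "colwise_qkv",
#     "layers.*.mlp.gate_up_proj": "colwise_merged",
# }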