diff --git a/src/transformers/fusion_mapping.py b/src/transformers/fusion_mapping.py
new file mode 100644
index 000000000000..4ecf92613102
--- /dev/null
+++ b/src/transformers/fusion_mapping.py
@@ -0,0 +1,199 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+import torch
+from torch import nn
+
+from .activations import ACT2FN
+from .core_model_loading import Concatenate, WeightConverter
+from .monkey_patching import register_patch_mapping
+
+
+if TYPE_CHECKING:
+    from .modeling_utils import PreTrainedModel
+
+
+def _make_fused_mlp(original_cls):
+    """
+    Create a fused MLP class from an original that has separate ``gate_proj`` and
+    ``up_proj``. The returned subclass replaces those with a single ``gate_up_proj``
+    and overrides ``_compute_gate_up`` to use it.
+    """
+
+    class ActAndMul(nn.Module):
+        """
+        Computes x -> act(x[..., :d]) * x[..., d:] where d = x.shape[-1] // 2.
+
+        Shapes:
+            x: (..., 2 * d)
+            return: (..., d)
+        """
+
+        def __init__(self, act):
+            super().__init__()
+            self.act = act
+            ActAndMul.__name__ = f"{type(act).__name__}AndMul"
+            ActAndMul.__qualname__ = f"{type(act).__qualname__}AndMul"
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            d = input.shape[-1] // 2
+            return self.act(input[..., :d]) * input[..., d:]
+
+    class FusedMLP(original_cls):
+        _weight_converter = WeightConverter(
+            source_patterns=[".gate_proj.weight$", ".up_proj.weight$"],
+            target_patterns=".gate_up_proj.weight$",
+            operations=[Concatenate(dim=0)],
+        )
+
+        def __init__(self, config):
+            super().__init__(config)
+
+            in_features = self.gate_proj.in_features
+            out_features = self.gate_proj.out_features + self.up_proj.out_features
+            bias = self.gate_proj.bias is not None
+            del self.gate_proj, self.up_proj
+
+            self.gate_up_proj = nn.Linear(in_features=in_features, out_features=out_features, bias=bias)
+            self.act_fn = ActAndMul(self.act_fn)
+
+        def _compute_gate_up(self, x):
+            return self.act_fn(self.gate_up_proj(x))
+
+    FusedMLP.__name__ = f"Fused{original_cls.__name__}"
+    FusedMLP.__qualname__ = f"Fused{original_cls.__qualname__}"
+    return FusedMLP
+
+
+def _make_fused_attention(original_cls):
+    """
+    Create a fused attention class from an original that has separate ``q_proj``,
+    ``k_proj``, ``v_proj``. The returned subclass replaces those with a single
+    ``qkv_proj`` and overrides ``_project_qkv`` to split the fused output.
+ """ + + class FusedAttention(original_cls): + _weight_converter = WeightConverter( + source_patterns=[".q_proj.weight$", ".k_proj.weight$", ".v_proj.weight$"], + target_patterns=".qkv_proj.weight$", + operations=[Concatenate(dim=0)], + ) + + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + + in_features = self.q_proj.in_features + self.q_size = self.q_proj.out_features + self.kv_size = self.k_proj.out_features + out_features = self.q_size + 2 * self.kv_size + bias = self.q_proj.bias is not None + del self.q_proj, self.k_proj, self.v_proj + + self.qkv_proj = nn.Linear(in_features=in_features, out_features=out_features, bias=bias) + + def _project_qkv(self, hidden_states, hidden_shape): + qkv = self.qkv_proj(hidden_states) + q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) + return ( + q.view(hidden_shape).transpose(1, 2), + k.view(hidden_shape).transpose(1, 2), + v.view(hidden_shape).transpose(1, 2), + ) + + FusedAttention.__name__ = f"Fused{original_cls.__name__}" + FusedAttention.__qualname__ = f"Fused{original_cls.__qualname__}" + return FusedAttention + + +_fusion_cache: dict[type, dict[str, type[nn.Module]]] = {} + + +def _discover_fusable_classes(cls: "type[PreTrainedModel]", config) -> dict[str, type[nn.Module]]: + """ + Instantiate *cls* on the meta device and walk ``model.modules()`` to find + ``nn.Module`` subclasses that expose ``_compute_gate_up`` (MLP fusion) or + ``_project_qkv`` (attention fusion). Returns a mapping from original + class name → fused replacement class. + """ + if cls in _fusion_cache: + return _fusion_cache[cls] + + with torch.device("meta"): + model = cls(config) + + seen: set[type] = set() + patch_mapping: dict[str, type[nn.Module]] = {} + for submodule in model.modules(): + subcls = type(submodule) + if subcls in seen: + continue + seen.add(subcls) + if hasattr(subcls, "_compute_gate_up"): + patch_mapping[subcls.__name__] = _make_fused_mlp(subcls) + elif hasattr(subcls, "_project_qkv"): + patch_mapping[subcls.__name__] = _make_fused_attention(subcls) + + _fusion_cache[cls] = patch_mapping + return patch_mapping + + +def _update_tp_plan(tp_plan: dict[str, str]) -> None: + """Rewrite *tp_plan* in-place to reflect fused projections.""" + for q_key in [k for k in tp_plan if k.endswith(".q_proj")]: + prefix = q_key.rsplit(".q_proj", 1)[0] + k_key, v_key = f"{prefix}.k_proj", f"{prefix}.v_proj" + if k_key in tp_plan and v_key in tp_plan: + del tp_plan[q_key], tp_plan[k_key], tp_plan[v_key] + tp_plan[f"{prefix}.qkv_proj"] = "colwise_qkv" + + for gate_key in [k for k in tp_plan if k.endswith(".gate_proj")]: + prefix = gate_key.rsplit(".gate_proj", 1)[0] + up_key = f"{prefix}.up_proj" + if up_key in tp_plan: + del tp_plan[gate_key], tp_plan[up_key] + tp_plan[f"{prefix}.gate_up_proj"] = "colwise_merged" + + +def register_fusion_patches(cls: "type[PreTrainedModel]", config) -> None: + """ + Register all fusion-related changes for *cls*: + + 1. Monkey-patches into the global patch mapping (for ``apply_patches()``) + 2. Weight converters into the checkpoint conversion mapping + (for ``get_model_conversion_mapping()``) + 3. TP plan updates on ``config_class.base_model_tp_plan`` + """ + from .conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping + + fusable_classes = _discover_fusable_classes(cls, config) + if not fusable_classes: + return + + # 1. monkey-patches + register_patch_mapping(fusable_classes) + + # 2. 
+    config_class = getattr(cls, "config_class", None)
+    model_type = getattr(config_class, "model_type", None) if config_class is not None else None
+    if model_type is not None:
+        converters = [fused_cls._weight_converter for fused_cls in fusable_classes.values()]
+        existing = get_checkpoint_conversion_mapping(model_type)
+        if existing is not None:
+            converters = existing + converters
+        register_checkpoint_conversion_mapping(model_type, converters, overwrite=True)
+
+    # 3. tp plan
+    if config_class is not None and config_class.base_model_tp_plan is not None:
+        _update_tp_plan(config_class.base_model_tp_plan)
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 1cdb033cb709..75b488bc7310 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3689,6 +3689,7 @@ def from_pretrained(
         revision: str = "main",
         use_safetensors: bool | None = None,
         weights_only: bool = True,
+        fuse_layers: bool = False,
         **kwargs,
     ) -> SpecificPreTrainedModelType:
         r"""
@@ -3867,6 +3868,8 @@ def from_pretrained(
                 Indicates whether unpickler should be restricted to loading only tensors, primitive types,
                 dictionaries and any types added via torch.serialization.add_safe_globals(). When set to False,
                 we can load wrapper tensor subclass weights.
+            fuse_layers (`bool`, *optional*, defaults to `False`):
+                Whether or not to fuse some layers of the model when loading it. This should only be used as an inference optimization.
             key_mapping (`dict[str, str], *optional*):
                 A potential mapping of the weight names if using a model on the Hub which is compatible to a
                 Transformers architecture, but was not converted accordingly.
@@ -4080,6 +4083,13 @@ def from_pretrained(
             )
 
         config.name_or_path = pretrained_model_name_or_path
+
+        # Register any fusion patch mappings necessary to fuse layers at initialization for inference performance
+        if fuse_layers:
+            from .fusion_mapping import register_fusion_patches
+
+            register_fusion_patches(cls, config)
+
         model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called, allow_all_kernels)
 
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
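
Note (reviewer sketch, not part of the diff): with the `fuse_layers` parameter above, fusion stays opt-in at load time. A hypothetical call site could look like the following; the checkpoint name is only an example, and the flag is simply forwarded through `from_pretrained` kwargs as added in this patch.

```python
from transformers import AutoModelForCausalLM

# Hypothetical usage of the new opt-in flag; "meta-llama/Llama-3.1-8B" is just an example checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    fuse_layers=True,  # register fusion patches/converters before the model is instantiated
)
```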
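The correctness of both `WeightConverter`s rests on one linear-algebra fact: concatenating projection weights along dim 0 stacks output features, so slicing the fused output along the last dimension recovers the unfused projections. A minimal standalone check of the gate/up case (toy sizes, plain PyTorch, independent of the classes in this PR):

```python
import torch
from torch import nn

hidden, intermediate = 8, 16
x = torch.randn(2, hidden)

# Unfused MLP projections, as in LlamaMLP.
gate_proj = nn.Linear(hidden, intermediate, bias=False)
up_proj = nn.Linear(hidden, intermediate, bias=False)
act = nn.SiLU()
reference = act(gate_proj(x)) * up_proj(x)

# Fused projection: concatenate the weights along dim 0 (output features), which is
# what Concatenate(dim=0) does to the checkpoint, then split the output at runtime.
gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)
with torch.no_grad():
    gate_up_proj.weight.copy_(torch.cat([gate_proj.weight, up_proj.weight], dim=0))

gate_up = gate_up_proj(x)
d = gate_up.shape[-1] // 2
fused = act(gate_up[..., :d]) * gate_up[..., d:]  # what ActAndMul computes

torch.testing.assert_close(reference, fused)
```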
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 9d659c7c6f08..7486ae4a3245 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -179,9 +179,11 @@ def __init__(self, config):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
         self.act_fn = ACT2FN[config.hidden_act]
 
+    def _compute_gate_up(self, x):
+        return self.act_fn(self.gate_proj(x)) * self.up_proj(x)
+
     def forward(self, x):
-        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
+        return self.down_proj(self._compute_gate_up(x))
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -248,6 +250,12 @@ def __init__(self, config: LlamaConfig, layer_idx: int):
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
 
+    def _project_qkv(self, hidden_states, hidden_shape):
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        return query_states, key_states, value_states
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -259,9 +267,7 @@ def forward(
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 
-        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        query_states, key_states, value_states = self._project_qkv(hidden_states, hidden_shape)
 
         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
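
The `_project_qkv` hook added above is the seam the fused attention class overrides: q/k/v weights are concatenated along dim 0 and the single matmul's output is split back with `torch.split`. A standalone sketch of that equivalence, mirroring the gate/up check earlier (toy GQA sizes, plain PyTorch, not transformers code):

```python
import torch
from torch import nn

hidden, q_size, kv_size = 16, 16, 8  # e.g. 4 query heads, 2 KV heads, head_dim 4
x = torch.randn(2, 3, hidden)        # (batch, seq, hidden)

q_proj = nn.Linear(hidden, q_size, bias=False)
k_proj = nn.Linear(hidden, kv_size, bias=False)
v_proj = nn.Linear(hidden, kv_size, bias=False)

# Fused projection built the same way Concatenate(dim=0) builds qkv_proj.weight.
qkv_proj = nn.Linear(hidden, q_size + 2 * kv_size, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

# Splitting the fused output reproduces the three separate projections.
q, k, v = torch.split(qkv_proj(x), [q_size, kv_size, kv_size], dim=-1)
torch.testing.assert_close(q, q_proj(x))
torch.testing.assert_close(k, k_proj(x))
torch.testing.assert_close(v, v_proj(x))
```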