diff --git a/fused_qwen_example.py b/fused_qwen_example.py
new file mode 100644
index 000000000000..0b7abd98439c
--- /dev/null
+++ b/fused_qwen_example.py
@@ -0,0 +1,30 @@
+import copy
+import torch
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
+from transformers.integrations import unfuse_modules
+
+
+model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+kernel_config = KernelConfig({
+    (
+        ("RMSNorm", "model.layers.*.post_attention_layernorm"),
+        ("MLP", "model.layers.*.mlp"),
+    ): "michaelbenayoun/dummy-rmsnorm-mlp:RMSNormMLP",
+})
+
+model = AutoModelForCausalLM.from_pretrained(model_id, use_kernels=True, kernel_config=kernel_config, device_map="cuda")
+
+input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids.to(model.device)
+
+original_model = copy.deepcopy(model)
+unfuse_modules(original_model)
+original_model.eval()
+
+with torch.no_grad():
+    fused_out = model(input_ids).logits
+    original_out = original_model(input_ids).logits
+
+print("Max diff fused vs original:", (fused_out - original_out).abs().max().item())
diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py
index 336db3773f76..f74836d890f7 100755
--- a/src/transformers/integrations/__init__.py
+++ b/src/transformers/integrations/__init__.py
@@ -67,10 +67,15 @@
     ],
     "hqq": ["prepare_for_hqq_linear"],
     "hub_kernels": [
+        "FusedModuleBase",
         "LayerRepository",
+        "fuse_modules",
         "lazy_load_kernel",
+        "make_fused_module_class",
+        "register_fusion_patterns",
         "register_kernel_mapping",
         "replace_kernel_forward_from_hub",
+        "unfuse_modules",
         "use_kernel_forward_from_hub",
         "use_kernel_func_from_hub",
         "use_kernelized_func",
@@ -221,10 +226,15 @@
     from .higgs import HiggsLinear, dequantize_higgs, quantize_with_higgs, replace_with_higgs_linear
     from .hqq import prepare_for_hqq_linear
     from .hub_kernels import (
+        FusedModuleBase,
         LayerRepository,
+        fuse_modules,
         lazy_load_kernel,
+        make_fused_module_class,
+        register_fusion_patterns,
         register_kernel_mapping,
         replace_kernel_forward_from_hub,
+        unfuse_modules,
         use_kernel_forward_from_hub,
         use_kernel_func_from_hub,
         use_kernelized_func,
diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index 70a343424aa8..e3ede38fd0eb 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
 import importlib.metadata
 import os
 import re
@@ -21,10 +22,14 @@
 from packaging import version as pkg_version
 
 from ..utils import ENV_VARS_TRUE_VALUES, logging
-from ..utils.import_utils import is_kernels_available
+from ..utils.import_utils import is_kernels_available, is_torch_available
 from .flash_attention import flash_attention_forward
 
 
+if is_torch_available():
+    import torch.nn as nn
+
+
 logger = logging.get_logger(__name__)
 
 try:
@@ -500,14 +505,212 @@ def allow_all_hub_kernels():
     ALLOW_ALL_KERNELS = False
 
 
+# Model class → {kernel_layer_name → [glob patterns]}
+# Populated via `register_fusion_patterns` for models that cannot be modified directly.
+_FUSION_PATTERNS_REGISTRY: dict[type, dict[str, list[str]]] = {}
+
+
+def register_fusion_patterns(
+    model_class_or_instance,
+    patterns: dict[str, list[str]],
+) -> None:
+    """
+    Register kernel fusion patterns for a model class without modifying it directly.
+
+    This is an alternative to setting `_kernel_fusion_patterns` as a class attribute,
+    useful when the model class is frozen or comes from an external library.
+
+    Args:
+        model_class_or_instance:
+            The model class (or an instance of it) for which patterns are being registered.
+        patterns (`dict[str, list[str]]`):
+            Mapping from `kernel_layer_name` to a list of glob-style module paths,
+            identical in format to `_kernel_fusion_patterns`.
+    """
+    if not isinstance(model_class_or_instance, type):
+        model_class_or_instance = type(model_class_or_instance)
+    _FUSION_PATTERNS_REGISTRY[model_class_or_instance] = patterns
+
+
+class FusedModuleBase(nn.Module):
+    def __init__(
+        self,
+        modules_to_fuse: list["nn.Module"],
+        source_names: list[str],
+        fused_module_names: list[str] | None = None,
+    ):
+        """
+        Args:
+            modules_to_fuse: The source modules to fuse together.
+            source_names: The attribute names under which each module lives in its parent
+                (used to restore them on `unfuse_modules`).
+            fused_module_names: The names under which each source module is registered as a
+                child of this container (i.e. `self.<name>`). When `None`, the
+                `kernel_layer_name` attribute of each source module is used. Pass this
+                explicitly when the source modules do not carry `@use_kernel_forward_from_hub`.
+        """
+        super().__init__()
+        if len(modules_to_fuse) == 0:
+            raise ValueError("At least one module must be provided for fusion.")
+        if len(modules_to_fuse) != len(source_names):
+            raise ValueError("Length of modules_to_fuse and source_names must match.")
+
+        self._source_names = source_names
+
+        if fused_module_names is not None:
+            if len(fused_module_names) != len(modules_to_fuse):
+                raise ValueError("Length of fused_module_names and modules_to_fuse must match.")
+            for module, name in zip(modules_to_fuse, fused_module_names):
+                self.add_module(name, module)
+            self._fused_module_names = list(fused_module_names)
+        else:
+            for module in modules_to_fuse:
+                attr_name = getattr(module, "kernel_layer_name", None)
+                if attr_name is None:
+                    raise ValueError(
+                        f"Module {module} does not have a 'kernel_layer_name' attribute. "
+                        f"Either decorate it with @use_kernel_forward_from_hub or provide "
+                        f"explicit names via the inline pattern format: "
+                        f'(("<kernel_layer_name>", "<glob_pattern>"), ...).'
+                    )
+                self.add_module(attr_name, module)
+            self._fused_module_names = [m.kernel_layer_name for m in modules_to_fuse]
+
+        # `kernelize` validates the kernel's forward signature against the class being replaced.
+        # Since the fused container sits at the position of the first module in the chain, the
+        # kernel's forward must match that module's signature. We patch the class-level forward
+        # here (via `functools.wraps`) so the signature is correct when `kernelize` inspects it.
+        # The body raises because this forward is always replaced by the kernel before any call.
+        @functools.wraps(type(modules_to_fuse[0]).forward)
+        def forward(self, *args, **kwargs):
+            raise NotImplementedError("FusedModule is a placeholder and should not be called directly.")
+
+        self.__class__.forward = forward
+
+    def __repr__(self):
+        names = ", ".join(self._fused_module_names)
+        return f"{self.__class__.__name__}(fused=({names}))"
+
+
+@functools.cache
+def make_fused_module_class(source_layer_names: tuple[str, ...], kernel_layer_name: str) -> type:
+    """
+    Dynamically create and cache a `FusedModuleBase` subclass for a given fusion combination.
+
+    Args:
+        source_layer_names (`tuple[str, ...]`):
+            Ordered tuple of `kernel_layer_name` values of the modules being fused
+            (e.g. `("RMSNorm", "MLP")`). Used as the cache key — the same combination
+            always returns the same class object.
+        kernel_layer_name (`str`):
+            The name assigned to the fused class, used by `kernelize` to look up the
+            kernel in the mapping (e.g. `"RMSNormMLP"`).
+
+    Returns:
+        A subclass of `FusedModuleBase` with `kernel_layer_name` set as a class attribute.
+    """
+    return type(
+        f"Fused_{'_'.join(source_layer_names)}",
+        (FusedModuleBase,),
+        {"kernel_layer_name": kernel_layer_name},
+    )
+
+
+def fuse_modules(
+    model: "nn.Module",
+    module_names_to_fuse: list[str],
+    kernel_layer_name: str,
+    source_layer_names: list[str] | None = None,
+) -> None:
+    """
+    Fuse a sequence of submodules into a single `FusedModuleBase` subclass in-place.
+
+    For every parent module whose immediate children match all entries in
+    `module_names_to_fuse`, the function:
+
+    - replaces the first module with a `FusedModuleBase` subclass instance that holds
+      all source modules as named children,
+    - replaces the remaining modules with `nn.Identity()` pass-throughs.
+
+    The fused container's `forward` signature is patched to match the first source
+    module's `forward`, satisfying the `kernelize` signature check.
+
+    Args:
+        model (`nn.Module`):
+            The model to modify in-place.
+        module_names_to_fuse (`list[str]`):
+            Glob-style paths of the modules to fuse, e.g.
+            `["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]`.
+            Integer indices are replaced with `*` so the same pattern applies to
+            every repeated block.
+        kernel_layer_name (`str`):
+            The `kernel_layer_name` assigned to the fused class, used by `kernelize`
+            to look up the kernel in the mapping (e.g. `"RMSNormMLP"`).
+        source_layer_names (`list[str]`, *optional*):
+            Explicit names for the child modules inside the fused container
+            (e.g. `["RMSNorm", "MLP"]`). When `None`, the `kernel_layer_name`
+            attribute of each source module is used.
+ """ + pattern = re.compile(r"\d+") + for module_name, module in model.named_modules(): + generic_children = { + re.sub(pattern, "*", f"{module_name}.{n}" if module_name else n): (n, child) + for n, child in module.named_children() + } + if not all(p in generic_children for p in module_names_to_fuse): + continue + + child_names = [generic_children[p][0] for p in module_names_to_fuse] + modules_to_fuse = [generic_children[p][1] for p in module_names_to_fuse] + + if source_layer_names is not None: + resolved_names = tuple(source_layer_names) + else: + resolved_names = tuple(getattr(m, "kernel_layer_name") for m in modules_to_fuse) + FusedClass = make_fused_module_class(resolved_names, kernel_layer_name) + fused_instance = FusedClass(modules_to_fuse, child_names, fused_module_names=list(resolved_names)) + + module.add_module(child_names[0], fused_instance) + for child_name in child_names[1:]: + module.add_module(child_name, nn.Identity()) + + +def unfuse_modules(model: "nn.Module") -> None: + """ + Revert a previous `fuse_modules` call in-place, restoring the original modules. + + For each `FusedModuleBase` instance found in the model tree, the function: + + - restores the original first module at the fused container's position, + - restores the remaining original modules at their original positions + (replacing the `nn.Identity()` pass-throughs). + + Args: + model (`nn.Module`): The model to restore in-place. + """ + for parent in model.modules(): + for name, child in list(parent.named_children()): + if not isinstance(child, FusedModuleBase): + continue + orig_modules = [getattr(child, layer_name) for layer_name in child._fused_module_names] + parent.add_module(name, orig_modules[0]) + for sibling_name, orig_module in zip(child._source_names[1:], orig_modules[1:]): + parent.add_module(sibling_name, orig_module) + + __all__ = [ + "FusedModuleBase", "LayerRepository", - "use_kernel_forward_from_hub", - "use_kernel_func_from_hub", + "fuse_modules", + "get_kernel", + "lazy_load_kernel", + "make_fused_module_class", + "register_fusion_patterns", "register_kernel_mapping", "register_kernel_mapping_transformers", "replace_kernel_forward_from_hub", - "lazy_load_kernel", - "get_kernel", + "unfuse_modules", + "use_kernel_forward_from_hub", + "use_kernel_func_from_hub", "use_kernelized_func", ] # type: ignore diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b041964bbdfc..2b9cd76f4914 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3738,6 +3738,10 @@ def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None register_kernel_mapping_transformers() if kernel_config is not None and isinstance(kernel_config, KernelConfig): + # For n-to-1 entries (tuple keys), fuse the corresponding modules in the model + # before validation so that FusedModuleBase instances are visible to sanitize. + kernel_config.apply_fusions(self) + # This will make sure the mapping is valid, and the layers are registered in the model kernel_config.sanitize_kernel_mapping(self) diff --git a/src/transformers/utils/kernel_config.py b/src/transformers/utils/kernel_config.py index bb4f965ddbf4..3a64a3808279 100644 --- a/src/transformers/utils/kernel_config.py +++ b/src/transformers/utils/kernel_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from ..integrations.hub_kernels import _FUSION_PATTERNS_REGISTRY, fuse_modules
 from ..utils import PushToHubMixin
 
 
@@ -116,6 +117,54 @@ def update_kernel(self, repo_id, registered_name, layer_name, device, mode, revi
             }
         }
 
+    def apply_fusions(self, model):
+        """
+        For each n-to-1 entry (tuple key) in the kernel mapping, find the fusion patterns
+        registered on the model, fuse the corresponding modules in-place, then replace the
+        tuple key with the resolved kernel layer name so the rest of the pipeline is unchanged.
+        """
+        new_mapping = {}
+        for layer_name, kernel in self.kernel_mapping.items():
+            if not isinstance(layer_name, tuple):
+                new_mapping[layer_name] = kernel
+                continue
+
+            # Parse the target kernel layer name from the repo string (the part after ':')
+            repo_str = kernel if isinstance(kernel, str) else next(iter(kernel.values()))
+            if isinstance(repo_str, dict):
+                repo_str = next(iter(repo_str.values()))
+            kernel_layer_name = repo_str.split(":")[1] if ":" in repo_str else repo_str.split("/")[-1]
+
+            # Only fuse if a kernel is available for the current device
+            current_device = infer_device(model)
+            has_kernel_for_device = isinstance(kernel, str) or current_device in kernel
+            if not has_kernel_for_device:
+                continue
+
+            # Detect inline format: tuple of (kernel_layer_name, glob_pattern) pairs.
+            # e.g. (("RMSNorm", "model.layers.*.post_attention_layernorm"), ("MLP", "model.layers.*.mlp"))
+            is_inline = all(isinstance(item, tuple) and len(item) == 2 for item in layer_name)
+            if is_inline:
+                source_names = [item[0] for item in layer_name]
+                patterns = [item[1] for item in layer_name]
+                fuse_modules(model, patterns, kernel_layer_name, source_layer_names=source_names)
+            else:
+                fusion_patterns = getattr(model, "_kernel_fusion_patterns", None) or _FUSION_PATTERNS_REGISTRY.get(
+                    type(model), {}
+                )
+                if kernel_layer_name not in fusion_patterns:
+                    raise ValueError(
+                        f"{type(model).__name__} does not define fusion patterns for '{kernel_layer_name}'. "
+                        f'Either add `_kernel_fusion_patterns = {{"{kernel_layer_name}": [...]}}` to the model class, '
+                        f"call `register_fusion_patterns({type(model).__name__}, ...)` before loading the model, "
+                        f"or use the inline pattern format: "
+                        f'(("{kernel_layer_name}", "<glob_pattern>"), ...).'
+                    )
+                fuse_modules(model, fusion_patterns[kernel_layer_name], kernel_layer_name)
+            new_mapping[kernel_layer_name] = kernel
+
+        self.kernel_mapping = new_mapping
+
     def store_registered_layer_names(self, model):
         for name, module in model.named_modules():
             if hasattr(module, "kernel_layer_name"):
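For reference, the new helpers can also be exercised by hand, outside of `KernelConfig`. The snippet below is an illustrative sketch and is not part of the diff: it reuses the Qwen3 checkpoint and module paths from `fused_qwen_example.py`, passes explicit `source_layer_names` (since the raw modules may not expose `kernel_layer_name` before kernelization), and the printed reprs are the expected output per the `FusedModuleBase.__repr__` and `make_fused_module_class` implementations above.

```python
from transformers import AutoModelForCausalLM
from transformers.integrations import fuse_modules, unfuse_modules

# Tiny random Qwen3 checkpoint reused from fused_qwen_example.py above.
model = AutoModelForCausalLM.from_pretrained("michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random")

# Fuse each decoder layer's post-attention RMSNorm with its MLP. The fused container
# takes the place of the first module in the chain; the MLP slot becomes an nn.Identity().
fuse_modules(
    model,
    ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"],
    "RMSNormMLP",
    source_layer_names=["RMSNorm", "MLP"],
)
print(model.model.layers[0].post_attention_layernorm)  # expected: Fused_RMSNorm_MLP(fused=(RMSNorm, MLP))
print(model.model.layers[0].mlp)                       # expected: Identity()

# Restore the original modules in place.
unfuse_modules(model)
print(model.model.layers[0].mlp)                       # expected: the original Qwen3 MLP module
```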