n-to-1 kernel fusion via KernelConfig #45363
New file (example script):

```python
import copy
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
from transformers.integrations import unfuse_modules

model_id = "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random"
tokenizer = AutoTokenizer.from_pretrained(model_id)

kernel_config = KernelConfig({
    (
        ("RMSNorm", "model.layers.*.post_attention_layernorm"),
        ("MLP", "model.layers.*.mlp"),
    ): "michaelbenayoun/dummy-rmsnorm-mlp:RMSNormMLP",
})

model = AutoModelForCausalLM.from_pretrained(
    model_id, use_kernels=True, kernel_config=kernel_config, device_map="cuda"
)

input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids.to(model.device)

original_model = copy.deepcopy(model)
unfuse_modules(original_model)
original_model.eval()

with torch.no_grad():
    fused_out = model(input_ids).logits
    original_out = original_model(input_ids).logits

print("Max diff fused vs original:", (fused_out - original_out).abs().max().item())
```
Changes to the kernels integration module:

```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
 import importlib.metadata
 import os
 import re
@@ -21,10 +22,14 @@
 from packaging import version as pkg_version
 
 from ..utils import ENV_VARS_TRUE_VALUES, logging
-from ..utils.import_utils import is_kernels_available
+from ..utils.import_utils import is_kernels_available, is_torch_available
 from .flash_attention import flash_attention_forward
 
 
+if is_torch_available():
+    import torch.nn as nn
+
+
 logger = logging.get_logger(__name__)
 
 try:
```
The last hunk (`@@ -500,14 +505,212 @@ def allow_all_hub_kernels():`) adds the fusion machinery right after `ALLOW_ALL_KERNELS = False`:

```python
# Model class → {kernel_layer_name → [glob patterns]}
# Populated via `register_fusion_patterns` for models that cannot be modified directly.
_FUSION_PATTERNS_REGISTRY: dict[type, dict[str, list[str]]] = {}


def register_fusion_patterns(
    model_class_or_instance,
    patterns: dict[str, list[str]],
) -> None:
    """
    Register kernel fusion patterns for a model class without modifying it directly.

    This is an alternative to setting `_kernel_fusion_patterns` as a class attribute,
    useful when the model class is frozen or comes from an external library.

    Args:
        model_class_or_instance:
            The model class (or an instance of it) for which patterns are being registered.
        patterns (`dict[str, list[str]]`):
            Mapping from `kernel_layer_name` to a list of glob-style module paths,
            identical in format to `_kernel_fusion_patterns`.
    """
    if not isinstance(model_class_or_instance, type):
        model_class_or_instance = type(model_class_or_instance)
    _FUSION_PATTERNS_REGISTRY[model_class_or_instance] = patterns
```
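A minimal usage sketch, assuming an arbitrary externally-defined model class (`Qwen3ForCausalLM` is only an illustrative choice, not something this PR registers):

```python
# Hypothetical usage: attach RMSNorm+MLP fusion patterns to a model class
# without editing its definition. `register_fusion_patterns` is the function
# defined above; the pattern format mirrors `_kernel_fusion_patterns`.
from transformers import Qwen3ForCausalLM  # illustrative class choice

register_fusion_patterns(
    Qwen3ForCausalLM,
    {
        # kernel_layer_name -> glob-style module paths
        "RMSNormMLP": [
            "model.layers.*.post_attention_layernorm",
            "model.layers.*.mlp",
        ],
    },
)
```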
**Author (member), commenting on `class FusedModuleBase(nn.Module):`:** @ArthurZucker the API is way simpler than #44979 because the whole purpose of fusion here is to replace the forward with a kernel. So we do not need all the complex machinery. What we need is just a module and being able to replace its forward with the kernel forward.

```python
class FusedModuleBase(nn.Module):
    def __init__(
        self,
        modules_to_fuse: list["nn.Module"],
        source_names: list[str],
        fused_module_names: list[str] | None = None,
    ):
        """
        Args:
            modules_to_fuse: The source modules to fuse together.
            source_names: The attribute names under which each module lives in its parent
                (used to restore them on `unfuse_modules`).
            fused_module_names: The names under which each source module is registered as a
                child of this container (i.e. `self.<name>`). When `None`, the
                `kernel_layer_name` attribute of each source module is used. Pass this
                explicitly when the source modules do not carry `@use_kernel_forward_from_hub`.
        """
        super().__init__()
        if len(modules_to_fuse) == 0:
            raise ValueError("At least one module must be provided for fusion.")
        if len(modules_to_fuse) != len(source_names):
            raise ValueError("Length of modules_to_fuse and source_names must match.")

        self._source_names = source_names

        if fused_module_names is not None:
            if len(fused_module_names) != len(modules_to_fuse):
                raise ValueError("Length of fused_module_names and modules_to_fuse must match.")
            for module, name in zip(modules_to_fuse, fused_module_names):
                self.add_module(name, module)
            self._fused_module_names = list(fused_module_names)
        else:
            for module in modules_to_fuse:
                attr_name = getattr(module, "kernel_layer_name", None)
                if attr_name is None:
                    raise ValueError(
                        f"Module {module} does not have a 'kernel_layer_name' attribute. "
                        f"Either decorate it with @use_kernel_forward_from_hub or provide "
                        f"explicit names via the inline pattern format: "
                        f'(("<name>", "<glob_path>"), ...).'
                    )
                self.add_module(attr_name, module)
            self._fused_module_names = [m.kernel_layer_name for m in modules_to_fuse]

        # `kernelize` validates the kernel's forward signature against the class being replaced.
        # Since the fused container sits at the position of the first module in the chain, the
        # kernel's forward must match that module's signature. We patch the class-level forward
        # here (via `functools.wraps`) so the signature is correct when `kernelize` inspects it.
        # The body raises because this forward is always replaced by the kernel before any call.
        @functools.wraps(type(modules_to_fuse[0]).forward)
        def forward(self, *args, **kwargs):
            raise NotImplementedError("FusedModule is a placeholder and should not be called directly.")

        self.__class__.forward = forward

    def __repr__(self):
        names = ", ".join(self._fused_module_names)
        return f"{self.__class__.__name__}(fused=({names}))"
```
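To make the container's behavior concrete, here is a small sketch with toy modules; explicit `fused_module_names` are required because plain `nn.Module`s carry no `kernel_layer_name` (assumes the class as reconstructed above):

```python
import torch.nn as nn

norm, mlp = nn.LayerNorm(8), nn.Linear(8, 8)
fused = FusedModuleBase(
    [norm, mlp],
    source_names=["post_attention_layernorm", "mlp"],
    fused_module_names=["RMSNorm", "MLP"],
)

print(fused)  # FusedModuleBase(fused=(RMSNorm, MLP))
assert fused.RMSNorm is norm and fused.MLP is mlp  # sources kept as named children

try:
    fused(None)  # the placeholder forward always raises; kernelize replaces it
except NotImplementedError as exc:
    print(exc)
```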
```python
@functools.cache
def make_fused_module_class(source_layer_names: tuple[str, ...], kernel_layer_name: str) -> type:
    """
    Dynamically create and cache a `FusedModuleBase` subclass for a given fusion combination.

    Args:
        source_layer_names (`tuple[str, ...]`):
            Ordered tuple of `kernel_layer_name` values of the modules being fused
            (e.g. `("RMSNorm", "MLP")`). Used as the cache key — the same combination
            always returns the same class object.
        kernel_layer_name (`str`):
            The name assigned to the fused class, used by `kernelize` to look up the
            kernel in the mapping (e.g. `"RMSNormMLP"`).

    Returns:
        A subclass of `FusedModuleBase` with `kernel_layer_name` set as a class attribute.
    """
    return type(
        f"Fused_{'_'.join(source_layer_names)}",
        (FusedModuleBase,),
        {"kernel_layer_name": kernel_layer_name},
    )
```
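A quick check of the factory's behavior, as a sketch against the function reconstructed above:

```python
FusedCls = make_fused_module_class(("RMSNorm", "MLP"), "RMSNormMLP")
print(FusedCls.__name__)           # Fused_RMSNorm_MLP
print(FusedCls.kernel_layer_name)  # RMSNormMLP

# Memoized via functools.cache: the same combination returns the identical
# class object, so kernelize sees exactly one class per fusion combination.
assert make_fused_module_class(("RMSNorm", "MLP"), "RMSNormMLP") is FusedCls
```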
```python
def fuse_modules(
    model: "nn.Module",
    module_names_to_fuse: list[str],
    kernel_layer_name: str,
    source_layer_names: list[str] | None = None,
) -> None:
    """
    Fuse a sequence of submodules into a single `FusedModuleBase` subclass in-place.

    For every parent module whose immediate children match all entries in
    `module_names_to_fuse`, the function:

    - replaces the first module with a `FusedModuleBase` subclass instance that holds
      all source modules as named children,
    - replaces the remaining modules with `nn.Identity()` pass-throughs.

    The fused container's `forward` signature is patched to match the first source
    module's `forward`, satisfying the `kernelize` signature check.

    Args:
        model (`nn.Module`):
            The model to modify in-place.
        module_names_to_fuse (`list[str]`):
            Glob-style paths of the modules to fuse, e.g.
            `["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]`.
            Integer indices are replaced with `*` so the same pattern applies to
            every repeated block.
        kernel_layer_name (`str`):
            The `kernel_layer_name` assigned to the fused class, used by `kernelize`
            to look up the kernel in the mapping (e.g. `"RMSNormMLP"`).
        source_layer_names (`list[str]`, *optional*):
            Explicit names for the child modules inside the fused container
            (e.g. `["RMSNorm", "MLP"]`). When `None`, the `kernel_layer_name`
            attribute of each source module is used.
    """
    pattern = re.compile(r"\d+")
    for module_name, module in model.named_modules():
        generic_children = {
            re.sub(pattern, "*", f"{module_name}.{n}" if module_name else n): (n, child)
            for n, child in module.named_children()
        }
        if not all(p in generic_children for p in module_names_to_fuse):
            continue

        child_names = [generic_children[p][0] for p in module_names_to_fuse]
        modules_to_fuse = [generic_children[p][1] for p in module_names_to_fuse]

        if source_layer_names is not None:
            resolved_names = tuple(source_layer_names)
        else:
            resolved_names = tuple(getattr(m, "kernel_layer_name") for m in modules_to_fuse)
        FusedClass = make_fused_module_class(resolved_names, kernel_layer_name)
        fused_instance = FusedClass(modules_to_fuse, child_names, fused_module_names=list(resolved_names))

        module.add_module(child_names[0], fused_instance)
        for child_name in child_names[1:]:
            module.add_module(child_name, nn.Identity())
```

**Collaborator, commenting on lines +673 to +675 (the `add_module` replacements):** we're probably gonna have concerns with hooks, especially with TP potentially but also accelerate!
```python
def unfuse_modules(model: "nn.Module") -> None:
    """
    Revert a previous `fuse_modules` call in-place, restoring the original modules.

    For each `FusedModuleBase` instance found in the model tree, the function:

    - restores the original first module at the fused container's position,
    - restores the remaining original modules at their original positions
      (replacing the `nn.Identity()` pass-throughs).

    Args:
        model (`nn.Module`): The model to restore in-place.
    """
    for parent in model.modules():
        for name, child in list(parent.named_children()):
            if not isinstance(child, FusedModuleBase):
                continue
            orig_modules = [getattr(child, layer_name) for layer_name in child._fused_module_names]
            parent.add_module(name, orig_modules[0])
            for sibling_name, orig_module in zip(child._source_names[1:], orig_modules[1:]):
                parent.add_module(sibling_name, orig_module)
```
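Putting `fuse_modules` and `unfuse_modules` together, a hypothetical round-trip sketch on a toy model (the `Block`/`Toy` classes and all names are made up for illustration; `source_layer_names` is passed because plain `nn.Module`s carry no `kernel_layer_name`):

```python
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.post_attention_layernorm = nn.LayerNorm(8)
        self.mlp = nn.Linear(8, 8)


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(Block() for _ in range(2))


model = Toy()
fuse_modules(
    model,
    ["layers.*.post_attention_layernorm", "layers.*.mlp"],
    kernel_layer_name="RMSNormMLP",
    source_layer_names=["RMSNorm", "MLP"],
)
# The first module in each chain becomes the fused container; the rest
# become pass-throughs.
print(type(model.layers[0].post_attention_layernorm).__name__)  # Fused_RMSNorm_MLP
print(model.layers[0].mlp)                                      # Identity()

unfuse_modules(model)
print(model.layers[0].mlp)  # Linear(in_features=8, out_features=8, bias=True)
```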
```python
__all__ = [
    "FusedModuleBase",
    "LayerRepository",
    "use_kernel_forward_from_hub",
    "use_kernel_func_from_hub",
    "fuse_modules",
    "get_kernel",
    "lazy_load_kernel",
    "make_fused_module_class",
    "register_fusion_patterns",
    "register_kernel_mapping",
    "register_kernel_mapping_transformers",
    "replace_kernel_forward_from_hub",
    "unfuse_modules",
    "use_kernelized_func",
]  # type: ignore
```
**Review comment (presumably on the new example file):** To be removed before merging. Temporary file.