From 1032a255fb42f8e2856bff31ae9e0208e4c35996 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 24 Mar 2026 18:32:39 -0400 Subject: [PATCH 1/4] wip: add module_fusion module --- src/transformers/module_fusion.py | 218 ++++++++++++++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 src/transformers/module_fusion.py diff --git a/src/transformers/module_fusion.py b/src/transformers/module_fusion.py new file mode 100644 index 000000000000..a7c991665a10 --- /dev/null +++ b/src/transformers/module_fusion.py @@ -0,0 +1,218 @@ +# Copyright 2026 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from dataclasses import dataclass +from typing import Any + +from .utils import is_torch_available + + +if is_torch_available(): + import torch.nn as nn + + +@dataclass +class ModuleSpec: + """ + Describes the input and output tensor names for one module in a fusion chain. + + Args: + inputs (`list[str]`): + Names of the positional inputs to this module, in order. + outputs (`list[str]`): + Names of the outputs produced by this module, in order. + Use ``"_"`` as a placeholder for outputs that should be ignored / not wired. + """ + + inputs: list[str] + outputs: list[str] + + +class RegistryCollector(nn.Module): + """ + Transparent pass-through module that captures its inputs into a shared registry. + + Placed at the position of every module in the fusion chain except the last one. 
+ The registry is later consumed by `FusedModule`. + """ + + def __init__(self, spec: ModuleSpec, index: int, registry: dict[str, Any]): + super().__init__() + self.spec = spec + self.index = index + self._registry = registry + + def _input_key(self, name: str) -> str: + return f"in_{self.index}_{name}" + + def forward(self, *args, **kwargs): + for name, arg in zip(self.spec.inputs, args): + self._registry[self._input_key(name)] = arg + self._registry.update({self._input_key(name): value for name, value in kwargs.items()}) + return args[0] if len(args) == 1 else args + + +class FusedModule(nn.Module): + """ + Executes a chain of modules in a single forward call, using inputs previously + captured by `RegistryCollector` instances. + + The registry uses two namespaces: + - ``in_{i}_{name}`` — external inputs captured by collector ``i`` + - ``out_{i}_{name}`` — outputs produced by module ``i`` during fused execution + + For module ``i > 0``, inputs are resolved from ``out_{i-1}_{name}`` first, + then fall back to ``in_{i}_{name}`` for external inputs not produced by the chain. + """ + + def __init__(self, modules: list[nn.Module], specs: list[ModuleSpec], registry: dict[str, Any]): + super().__init__() + self.modules_to_fuse = nn.ModuleList(modules) + self.specs = specs + self._registry = registry + self._signatures = [inspect.signature(mod.forward) for mod in modules] + self._validate_specs() + + def _input_key(self, module_index: int, name: str) -> str: + return f"in_{module_index}_{name}" + + def _output_key(self, module_index: int, name: str) -> str: + return f"out_{module_index}_{name}" + + def _validate_specs(self): + if len(self.modules_to_fuse) != len(self.specs): + raise ValueError("Number of modules and specs must match.") + + # Build a mapping: out_{i}_{name} → i, for each module output. 
+ output_producers = {} + for i, spec in enumerate(self.specs): + for name in spec.outputs: + if name != "_": + output_producers[self._output_key(i, name)] = i + + for i, (mod, spec, sig) in enumerate(zip(self.modules_to_fuse, self.specs, self._signatures)): + if len(spec.inputs) != len(sig.parameters): + raise ValueError( + f"Module of type {type(mod)} expects {len(sig.parameters)} inputs " + f"but spec defines {len(spec.inputs)}." + ) + if i == 0: + continue # module 0 inputs come from collectors, always externally provided + for inp in spec.inputs: + key = self._output_key(i - 1, inp) + if key in output_producers and output_producers[key] > i - 1: + raise ValueError( + f"Module {i} requires '{inp}' but it is produced by module " + f"{output_producers[key]}, which comes later in the chain." + ) + + def forward(self, *args, **kwargs): + for name, arg in zip(self.specs[0].inputs, args): + self._registry[self._input_key(0, name)] = arg + self._registry.update({self._input_key(0, name): value for name, value in kwargs.items()}) + + outputs = None + for index, (mod, spec, sig) in enumerate(zip(self.modules_to_fuse, self.specs, self._signatures)): + param_names = list(sig.parameters.keys()) + inputs = {} + for spec_name, arg_name in zip(spec.inputs, param_names): + if index == 0: + key = self._input_key(0, spec_name) + else: + out_key = self._output_key(index - 1, spec_name) + key = out_key if out_key in self._registry else self._input_key(index, spec_name) + inputs[arg_name] = self._registry[key] + bound = sig.bind(**inputs) + bound.apply_defaults() + outputs = mod(**bound.arguments) + if not isinstance(outputs, tuple): + outputs = (outputs,) + self._registry.update( + {self._output_key(index, name): output for name, output in zip(spec.outputs, outputs) if name != "_"} + ) + + self._registry.clear() + + if outputs is None: + return None + return outputs[0] if len(outputs) == 1 else outputs + + def __repr__(self): + return f"FusedModule(fused={self.modules_to_fuse})" + 
+ +def fuse_modules( + model: nn.Module, + module_names_to_fuse: list[str], + module_specs: list[ModuleSpec], +) -> None: + """ + Fuse a sequence of submodules into a single `FusedModule` in-place. + + The function traverses the model tree and, for every parent module whose + immediate children match all entries in ``module_names_to_fuse``, replaces: + + - every module except the last with a `RegistryCollector` (transparent pass-through), + - the last module with a `FusedModule` that re-executes the full chain. + + Args: + model (`nn.Module`): + The model to modify in-place. + module_names_to_fuse (`list[str]`): + Glob-style paths of the modules to fuse, e.g. + ``["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]``. + Integer indices are replaced with ``*`` during matching so that the + same spec applies to every repeated block. + module_specs (`list[ModuleSpec]`): + One `ModuleSpec` per entry in `module_names_to_fuse`, + describing input/output tensor names for each module. 
+ + Example: + + specs = [ + ModuleSpec(inputs=["hidden_states"], outputs=["hidden_states"]), + ModuleSpec(inputs=["hidden_states"], outputs=["hidden_states"]), + ] + fuse_modules( + model, + ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"], + specs, + ) + """ + pattern = re.compile(r"\d+") + to_visit = [(model, "")] + + while to_visit: + parent, parent_name = to_visit.pop() + named_children = {} + for name, child in parent.named_children(): + full_name = f"{parent_name}.{name}" if parent_name else name + generic_name = re.sub(pattern, "*", full_name) + named_children[generic_name] = {"module": child, "name": name} + + if all(name in named_children for name in module_names_to_fuse): + registry = {} + modules_to_fuse = [named_children[name]["module"] for name in module_names_to_fuse] + + for index, (name, spec) in enumerate(zip(module_names_to_fuse[:-1], module_specs[:-1])): + attr_name = named_children[name]["name"] + parent.add_module(attr_name, RegistryCollector(spec, index, registry)) + + last_name = module_names_to_fuse[-1] + parent.add_module(named_children[last_name]["name"], FusedModule(modules_to_fuse, module_specs, registry)) + + for entry in named_children.values(): + to_visit.append((entry["module"], re.sub(pattern, "*", f"{parent_name}.{entry['name']}" if parent_name else entry["name"]))) From f81f990099d4752885816bf0fc62287ee32faf72 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 25 Mar 2026 15:29:36 -0400 Subject: [PATCH 2/4] fix: issue with traversal and add unfuse --- src/transformers/module_fusion.py | 67 ++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/src/transformers/module_fusion.py b/src/transformers/module_fusion.py index a7c991665a10..126cec16ae3f 100644 --- a/src/transformers/module_fusion.py +++ b/src/transformers/module_fusion.py @@ -104,10 +104,11 @@ def _validate_specs(self): output_producers[self._output_key(i, name)] = i for i, (mod, spec, sig) in 
enumerate(zip(self.modules_to_fuse, self.specs, self._signatures)): - if len(spec.inputs) != len(sig.parameters): + n_required = sum(1 for p in sig.parameters.values() if p.default is inspect.Parameter.empty) + if len(spec.inputs) < n_required or len(spec.inputs) > len(sig.parameters): raise ValueError( f"Module of type {type(mod)} expects {len(sig.parameters)} inputs " - f"but spec defines {len(spec.inputs)}." + f"({n_required} required) but spec defines {len(spec.inputs)}." ) if i == 0: continue # module 0 inputs come from collectors, always externally provided @@ -193,26 +194,52 @@ def fuse_modules( ) """ pattern = re.compile(r"\d+") - to_visit = [(model, "")] + for module_name, module in model.named_modules(): + generic_children = { + re.sub(pattern, "*", f"{module_name}.{n}" if module_name else n): (n, child) + for n, child in module.named_children() + } + if not all(p in generic_children for p in module_names_to_fuse): + continue + registry = {} + modules_to_fuse = [generic_children[p][1] for p in module_names_to_fuse] + for index, (p, spec) in enumerate(zip(module_names_to_fuse[:-1], module_specs[:-1])): + module.add_module(generic_children[p][0], RegistryCollector(spec, index, registry)) + last_p = module_names_to_fuse[-1] + module.add_module(generic_children[last_p][0], FusedModule(modules_to_fuse, module_specs, registry)) + + +def unfuse_modules(model: nn.Module) -> None: + """ + Revert a previous `fuse_modules` call in-place, restoring the original modules. 
+ + For each `FusedModule` found in the model tree, the function: - while to_visit: - parent, parent_name = to_visit.pop() - named_children = {} - for name, child in parent.named_children(): - full_name = f"{parent_name}.{name}" if parent_name else name - generic_name = re.sub(pattern, "*", full_name) - named_children[generic_name] = {"module": child, "name": name} + - replaces each sibling `RegistryCollector` with the corresponding original module + (recovered from `FusedModule.modules_to_fuse`), + - replaces the `FusedModule` itself with the last original module. - if all(name in named_children for name in module_names_to_fuse): - registry = {} - modules_to_fuse = [named_children[name]["module"] for name in module_names_to_fuse] + Collectors belonging to a given `FusedModule` are identified by sharing the same + ``_registry`` object. - for index, (name, spec) in enumerate(zip(module_names_to_fuse[:-1], module_specs[:-1])): - attr_name = named_children[name]["name"] - parent.add_module(attr_name, RegistryCollector(spec, index, registry)) + Args: + model (`nn.Module`): The model to restore in-place. - last_name = module_names_to_fuse[-1] - parent.add_module(named_children[last_name]["name"], FusedModule(modules_to_fuse, module_specs, registry)) + Example:: - for entry in named_children.values(): - to_visit.append((entry["module"], re.sub(pattern, "*", f"{parent_name}.{entry['name']}" if parent_name else entry["name"]))) + fuse_modules(model, ["model.layers.*.norm", "model.layers.*.mlp"], specs) + # ... optimized forward pass ... + unfuse_modules(model) # back to original + """ + for parent in model.modules(): + fused_children = {name: child for name, child in parent.named_children() if isinstance(child, FusedModule)} + for fused_name, fused in fused_children.items(): + # Collectors belonging to this FusedModule share the same registry object. 
+ collectors = { + name: child + for name, child in parent.named_children() + if isinstance(child, RegistryCollector) and child._registry is fused._registry + } + for col_name, collector in collectors.items(): + parent.add_module(col_name, fused.modules_to_fuse[collector.index]) + parent.add_module(fused_name, fused.modules_to_fuse[-1]) From dd511f4f81d806e97e08ab281a02bfc2b9bd95ed Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 25 Mar 2026 15:29:50 -0400 Subject: [PATCH 3/4] test: add tests for module fusion --- tests/test_module_fusion.py | 210 ++++++++++++++++++++++++++++++++++++ 1 file changed, 210 insertions(+) create mode 100644 tests/test_module_fusion.py diff --git a/tests/test_module_fusion.py b/tests/test_module_fusion.py new file mode 100644 index 000000000000..e76ee42ebf8e --- /dev/null +++ b/tests/test_module_fusion.py @@ -0,0 +1,210 @@ +# Copyright 2026 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import unittest + +from transformers import is_torch_available +from transformers.module_fusion import FusedModule, ModuleSpec, RegistryCollector, fuse_modules, unfuse_modules +from transformers.testing_utils import require_torch + + +if is_torch_available(): + import torch + import torch.nn as nn + + +class LinearWithScale(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(8, 8) + + def forward(self, x, scale=1.0): + return self.linear(x) * scale + + +class LayerNorm(nn.Module): + def __init__(self): + super().__init__() + self.norm = nn.LayerNorm(8) + + def forward(self, x, bias=None): + x = self.norm(x) + if bias is not None: + x = x + bias + return x + + +class DummyBlock(nn.Module): + def __init__(self): + super().__init__() + self.linear = LinearWithScale() + self.norm = LayerNorm() + + def forward(self, x, scale=1.0): + x = self.linear(x, scale=scale) + x = self.norm(x) + return x + + +class DummyModel(nn.Module): + def __init__(self, num_layers=2): + super().__init__() + self.layers = nn.ModuleList([DummyBlock() for _ in range(num_layers)]) + + def forward(self, x, scale=1.0): + for block in self.layers: + x = block(x, scale=scale) + return x + + +@require_torch +class TestModuleFusion(unittest.TestCase): + def setUp(self): + torch.manual_seed(0) + self.x = torch.randn(2, 8) + + # --- RegistryCollector --- + + def test_collector_captures_input_and_is_passthrough(self): + """RegistryCollector stores args in registry and returns input unchanged.""" + registry = {} + spec = ModuleSpec(inputs=["x"], outputs=["x"]) + collector = RegistryCollector(spec, index=2, registry=registry) + x = torch.randn(2, 8) + out = collector(x) + self.assertIn("in_2_x", registry) + self.assertIs(registry["in_2_x"], x) + self.assertTrue(torch.equal(out, x)) + + # --- FusedModule validation --- + + def test_fused_module_raises_on_spec_count_mismatch(self): + """Mismatched number of modules vs specs raises ValueError.""" + with 
self.assertRaises(ValueError): + FusedModule( + [LinearWithScale(), LayerNorm()], + [ModuleSpec(inputs=["x", "scale"], outputs=["x"])], # only 1 spec for 2 modules + {}, + ) + + def test_fused_module_raises_on_input_count_mismatch(self): + """Spec with more inputs than module params raises ValueError.""" + with self.assertRaises(ValueError): + FusedModule( + [LinearWithScale()], # forward(x, scale) → 2 params total + [ModuleSpec(inputs=["x", "scale", "extra"], outputs=["x"])], # 3 > 2 params + {}, + ) + + # --- FusedModule forward --- + + def test_fused_module_chains_outputs_to_next_inputs(self): + """Output of module 0 is passed as input to module 1 via registry.""" + linear = LinearWithScale() + norm = LayerNorm() + specs = [ + ModuleSpec(inputs=["x", "scale"], outputs=["x"]), + ModuleSpec(inputs=["x", "bias"], outputs=["x"]), + ] + # bias is not produced by linear, must be pre-populated in the registry + registry = {"in_1_bias": None} + fused = FusedModule([linear, norm], specs, registry) + out = fused(self.x, torch.tensor(1.0)) + expected = norm(linear(self.x, scale=1.0)) + self.assertTrue(torch.allclose(out, expected, atol=1e-6)) + + def test_fused_module_fallback_for_external_inputs(self): + """Input not produced by prior module falls back to collector-captured in_{i}_{name}.""" + linear = LinearWithScale() + norm = LayerNorm() + bias = torch.ones(8) * 0.5 + specs = [ + ModuleSpec(inputs=["x", "scale"], outputs=["x"]), + ModuleSpec(inputs=["x", "bias"], outputs=["x"]), + ] + registry = {"in_1_bias": bias} + fused = FusedModule([linear, norm], specs, registry) + out = fused(self.x, torch.tensor(1.0)) + expected = norm(linear(self.x, scale=1.0), bias=bias) + self.assertTrue(torch.allclose(out, expected, atol=1e-6)) + + def test_fused_module_registry_cleared_after_forward(self): + """Registry is empty after FusedModule.forward() so no state leaks between calls.""" + linear = LinearWithScale() + registry = {} + fused = FusedModule([linear], 
[ModuleSpec(inputs=["x", "scale"], outputs=["x"])], registry) + fused(self.x, torch.tensor(1.0)) + self.assertEqual(len(registry), 0) + + # --- fuse_modules --- + + def test_fuse_modules_structure(self): + """fuse_modules places RegistryCollectors on all but last, FusedModule on last.""" + model = DummyModel(num_layers=1) + specs = [ + ModuleSpec(inputs=["x", "scale"], outputs=["x"]), + ModuleSpec(inputs=["x"], outputs=["x"]), # bias has a default, omitted from spec + ] + fuse_modules(model, ["layers.*.linear", "layers.*.norm"], specs) + self.assertIsInstance(model.layers[0].linear, RegistryCollector) + self.assertIsInstance(model.layers[0].norm, FusedModule) + + def test_fuse_modules_numerical_equivalence(self): + """Fused model produces identical output to original for all layers.""" + model = DummyModel(num_layers=3) + original = copy.deepcopy(model) + specs = [ + ModuleSpec(inputs=["x", "scale"], outputs=["x"]), + ModuleSpec(inputs=["x"], outputs=["x"]), # bias has a default, omitted from spec + ] + fuse_modules(model, ["layers.*.linear", "layers.*.norm"], specs) + with torch.no_grad(): + self.assertTrue(torch.allclose(model(self.x), original(self.x), atol=1e-6)) + + def test_fuse_modules_each_layer_has_independent_registry(self): + """Each fused group uses its own registry; collector and FusedModule share it.""" + specs = [ + ModuleSpec(inputs=["x", "scale"], outputs=["x"]), + ModuleSpec(inputs=["x"], outputs=["x"]), + ] + num_layers = 5 + model = DummyModel(num_layers=num_layers) + fuse_modules(model, ["layers.*.linear", "layers.*.norm"], specs) + registries = [] + for index in range(num_layers): + block = model.layers[index] + registries.append(block.linear._registry) + + # All layers should have distinct registry objects + registries_ids = {id(reg) for reg in registries} + self.assertEqual(len(registries_ids), num_layers) + + # --- unfuse_modules --- + + def test_unfuse_restores_modules_and_numerical_equivalence(self): + """After fuse+unfuse, original 
modules are restored and output matches original.""" + model = DummyModel(num_layers=2) + original = copy.deepcopy(model) + orig_linear = model.layers[0].linear + specs = [ + ModuleSpec(inputs=["x", "scale"], outputs=["x"]), + ModuleSpec(inputs=["x"], outputs=["x"]), + ] + fuse_modules(model, ["layers.*.linear", "layers.*.norm"], specs) + unfuse_modules(model) + self.assertIs(model.layers[0].linear, orig_linear) + self.assertNotIsInstance(model.layers[0].norm, FusedModule) + with torch.no_grad(): + self.assertTrue(torch.allclose(model(self.x), original(self.x), atol=1e-6)) From 8ad1a5d3c66df4041daea054a0ed197185858aab Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 25 Mar 2026 16:40:14 -0400 Subject: [PATCH 4/4] feat: add attribute pass-through --- src/transformers/module_fusion.py | 19 +++++++++++++++++-- tests/test_module_fusion.py | 24 +++++++++++++++++++++++- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/src/transformers/module_fusion.py b/src/transformers/module_fusion.py index 126cec16ae3f..66ee719fa12b 100644 --- a/src/transformers/module_fusion.py +++ b/src/transformers/module_fusion.py @@ -49,12 +49,21 @@ class RegistryCollector(nn.Module): The registry is later consumed by `FusedModule`. """ - def __init__(self, spec: ModuleSpec, index: int, registry: dict[str, Any]): + def __init__(self, spec: ModuleSpec, index: int, registry: dict[str, Any], orig_module: nn.Module): super().__init__() self.spec = spec self.index = index + self.orig_module = orig_module self._registry = registry + def __getattr__(self, name: str) -> Any: + # This module is a transparent pass-through, so we delegate all attribute access to the original module. + # In particular, it allows access to the original parameters and buffers as if the collector was not there. 
+ try: + return super().__getattr__(name) # handles _modules, _parameters, _buffers + except AttributeError: + return getattr(self._modules["orig_module"], name) + def _input_key(self, name: str) -> str: return f"in_{self.index}_{name}" @@ -86,6 +95,12 @@ def __init__(self, modules: list[nn.Module], specs: list[ModuleSpec], registry: self._signatures = [inspect.signature(mod.forward) for mod in modules] self._validate_specs() + def __getattr__(self, name: str) -> Any: + try: + return super().__getattr__(name) # handles _modules, _parameters, _buffers + except AttributeError: + return getattr(self._modules["modules_to_fuse"][-1], name) + def _input_key(self, module_index: int, name: str) -> str: return f"in_{module_index}_{name}" @@ -204,7 +219,7 @@ def fuse_modules( registry = {} modules_to_fuse = [generic_children[p][1] for p in module_names_to_fuse] for index, (p, spec) in enumerate(zip(module_names_to_fuse[:-1], module_specs[:-1])): - module.add_module(generic_children[p][0], RegistryCollector(spec, index, registry)) + module.add_module(generic_children[p][0], RegistryCollector(spec, index, registry, modules_to_fuse[index])) last_p = module_names_to_fuse[-1] module.add_module(generic_children[last_p][0], FusedModule(modules_to_fuse, module_specs, registry)) diff --git a/tests/test_module_fusion.py b/tests/test_module_fusion.py index e76ee42ebf8e..ea67d6bbd3df 100644 --- a/tests/test_module_fusion.py +++ b/tests/test_module_fusion.py @@ -81,13 +81,23 @@ def test_collector_captures_input_and_is_passthrough(self): """RegistryCollector stores args in registry and returns input unchanged.""" registry = {} spec = ModuleSpec(inputs=["x"], outputs=["x"]) - collector = RegistryCollector(spec, index=2, registry=registry) + linear = LinearWithScale() + collector = RegistryCollector(spec, index=2, registry=registry, orig_module=linear) x = torch.randn(2, 8) out = collector(x) self.assertIn("in_2_x", registry) self.assertIs(registry["in_2_x"], x) 
self.assertTrue(torch.equal(out, x)) + def test_collector_delegates_attribute_access_to_orig_module(self): + """Attribute access on RegistryCollector is transparently forwarded to orig_module.""" + linear = LinearWithScale() + collector = RegistryCollector( + ModuleSpec(inputs=["x", "scale"], outputs=["x"]), index=0, registry={}, orig_module=linear + ) + self.assertIs(collector.linear, linear.linear) + self.assertIs(collector.linear.weight, linear.linear.weight) + # --- FusedModule validation --- def test_fused_module_raises_on_spec_count_mismatch(self): @@ -108,6 +118,18 @@ def test_fused_module_raises_on_input_count_mismatch(self): {}, ) + def test_fused_module_delegates_attribute_access_to_last_module(self): + """Attribute access on FusedModule is transparently forwarded to the last module in the chain.""" + linear = LinearWithScale() + norm = LayerNorm() + specs = [ + ModuleSpec(inputs=["x", "scale"], outputs=["x"]), + ModuleSpec(inputs=["x"], outputs=["x"]), + ] + fused = FusedModule([linear, norm], specs, {}) + self.assertIs(fused.norm, norm.norm) + self.assertIs(fused.norm.weight, norm.norm.weight) + # --- FusedModule forward --- def test_fused_module_chains_outputs_to_next_inputs(self):