diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 15c9fa44f5ff..891ac28455af 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -307,6 +307,17 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
         if self.group.onload_leader == module:
             if self.group.onload_self:
                 self.group.onload_()
+            else:
+                # onload_self=False means this group relies on prefetching from a previous group.
+                # However, for conditionally-executed modules (e.g. patch_short/patch_mid/patch_long in Helios),
+                # the prefetch chain may not cover them if they were absent during the first forward pass
+                # when the execution order was traced. In that case, their weights remain on offload_device,
+                # so we fall back to a synchronous onload here.
+                params = [p for m in self.group.modules for p in m.parameters()] + list(self.group.parameters or [])
+                if params and params[0].device == self.group.offload_device:
+                    self.group.onload_()
+                    if self.group.stream is not None:
+                        self.group.stream.synchronize()
 
         should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
         if should_onload_next_group:
diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py
index 236094109d07..108a7247bcc6 100644
--- a/tests/hooks/test_group_offloading.py
+++ b/tests/hooks/test_group_offloading.py
@@ -566,3 +566,127 @@ def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=Non
             "layers_per_block": 1,
         }
         return init_dict
+
+
+# Model with conditionally-executed modules, simulating Helios patch_short/patch_mid/patch_long behavior.
+# These modules are only called when optional inputs are provided, which means the lazy prefetch
+# execution order tracer may not see them on the first forward pass. This can cause a device mismatch
+# on subsequent calls when the modules ARE invoked but their weights were never onloaded.
+# See: https://github.com/huggingface/diffusers/pull/13211
+class DummyModelWithConditionalModules(ModelMixin):
+    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
+        super().__init__()
+
+        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+        self.activation = torch.nn.ReLU()
+        self.blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+        )
+        self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+        # These modules are only invoked when optional_input is not None.
+        # Output dimension matches hidden_features so they can be added after linear_1.
+        self.optional_proj_1 = torch.nn.Linear(in_features, hidden_features)
+        self.optional_proj_2 = torch.nn.Linear(in_features, hidden_features)
+
+    def forward(self, x: torch.Tensor, optional_input: "torch.Tensor | None" = None) -> torch.Tensor:
+        x = self.linear_1(x)
+        x = self.activation(x)
+        if optional_input is not None:
+            # Add optional projections after linear_1 so dimensions match (both hidden_features)
+            x = x + self.optional_proj_1(optional_input)
+            x = x + self.optional_proj_2(optional_input)
+        for block in self.blocks:
+            x = block(x)
+        x = self.linear_2(x)
+        return x
+
+
+class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
+    """Tests for conditionally-executed modules under group offloading with streams.
+
+    Regression tests for the case where a module is not executed during the first forward pass
+    (when the lazy prefetch execution order is traced), but IS executed on subsequent passes.
+    Without the fix, the weights of such modules remain on CPU while the input is on GPU,
+    causing a RuntimeError about tensor device mismatch.
+    """
+
+    def get_model(self):
+        torch.manual_seed(0)
+        return DummyModelWithConditionalModules(
+            in_features=self.in_features,
+            hidden_features=self.hidden_features,
+            out_features=self.out_features,
+            num_layers=self.num_layers,
+        )
+
+    @parameterized.expand([("leaf_level",), ("block_level",)])
+    @unittest.skipIf(
+        torch.device(torch_device).type not in ["cuda", "xpu"],
+        "Test requires a CUDA or XPU device.",
+    )
+    def test_conditional_modules_with_stream(self, offload_type: str):
+        """Regression test: conditionally-executed modules must not cause device mismatch when using streams.
+
+        The model contains two optional Linear layers (optional_proj_1, optional_proj_2) that are only
+        executed when `optional_input` is provided. This simulates modules like patch_short/patch_mid/
+        patch_long in HeliosTransformer3DModel, which are only called when history latents are present.
+
+        When using streams, `LazyPrefetchGroupOffloadingHook` traces the execution order on the first
+        forward pass and sets up a prefetch chain so each module pre-loads the next one's weights.
+        Modules not executed during this tracing pass are excluded from the prefetch chain.
+
+        The bug: if a module was absent from the first (tracing) pass, its `onload_self` flag gets set
+        to False (meaning "someone else will onload me"). But since it's not in the prefetch chain,
+        nobody ever does — so its weights remain on CPU. When the module is eventually called in a
+        subsequent pass, the input is on GPU but the weights are on CPU, causing a RuntimeError.
+
+        We therefore must invoke the model multiple times:
+        1. First pass WITHOUT optional_input: triggers the lazy prefetch tracing. optional_proj_1/2
+           are absent, so they are excluded from the prefetch chain.
+        2. Second pass WITH optional_input: the regression case. Without the fix, this raises a
+           RuntimeError because optional_proj_1/2 weights are still on CPU.
+        3. Third pass WITHOUT optional_input: verifies the model remains stable after having seen
+           both code paths.
+        """
+
+        model = self.get_model()
+        model_ref = self.get_model()
+        model_ref.load_state_dict(model.state_dict(), strict=True)
+        model_ref.to(torch_device)
+
+        model.enable_group_offload(
+            torch_device,
+            offload_type=offload_type,
+            num_blocks_per_group=1,
+            use_stream=True,
+        )
+
+        x = torch.randn(4, self.in_features).to(torch_device)
+        optional_input = torch.randn(4, self.in_features).to(torch_device)
+
+        with torch.no_grad():
+            # First forward pass WITHOUT optional_input — this is when the lazy prefetch
+            # execution order is traced. optional_proj_1/2 are NOT in the traced order.
+            out_ref_no_opt = model_ref(x, optional_input=None)
+            out_no_opt = model(x, optional_input=None)
+            self.assertTrue(
+                torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5),
+                f"[{offload_type}] Outputs do not match on first pass (no optional_input).",
+            )
+
+            # Second forward pass WITH optional_input — optional_proj_1/2 ARE now called.
+            out_ref_with_opt = model_ref(x, optional_input=optional_input)
+            out_with_opt = model(x, optional_input=optional_input)
+            self.assertTrue(
+                torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5),
+                f"[{offload_type}] Outputs do not match on second pass (with optional_input).",
+            )
+
+            # Third pass again without optional_input — verify stable behavior.
+            out_ref_no_opt2 = model_ref(x, optional_input=None)
+            out_no_opt2 = model(x, optional_input=None)
+            self.assertTrue(
+                torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5),
+                f"[{offload_type}] Outputs do not match on third pass (back to no optional_input).",
+            )