diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 9f3ef047d98d..6d81f8a13af7 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -556,14 +556,21 @@ class HeliosTransformer3DModel( _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["HeliosTransformerBlock"] _cp_plan = { - "blocks.0": { + # Input split at attn level and ffn level. + "blocks.*.attn1": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - }, - "blocks.*": { - "temb": ContextParallelInput(split_dim=1, expected_dims=4, split_output=False), "rotary_emb": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), }, - "blocks.39": ContextParallelOutput(gather_dim=1, expected_dims=3), + "blocks.*.attn2": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "blocks.*.ffn": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + # Output gather at attn level and ffn level. + **{f"blocks.{i}.attn1": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)}, + **{f"blocks.{i}.attn2": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)}, + **{f"blocks.{i}.ffn": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)}, } @register_to_config diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 0e08b2c6e958..d8f317a9a6f1 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -449,7 +449,14 @@ def sample_block_noise( width, patch_size: tuple[int, ...] = (1, 2, 2), device: torch.device | None = None, + generator: torch.Generator | None = None, ): + # NOTE: A generator must be provided to ensure correct and reproducible results. + # Creating a default generator here is a fallback only — without a fixed seed, + # the output will be non-deterministic and may produce incorrect results in CP context. + if generator is None: + generator = torch.Generator(device=device) + gamma = self.scheduler.config.gamma _, ph, pw = patch_size block_size = ph * pw @@ -458,13 +465,17 @@ def sample_block_noise( torch.eye(block_size, device=device) * (1 + gamma) - torch.ones(block_size, block_size, device=device) * gamma ) - cov += torch.eye(block_size, device=device) * 1e-6 - dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=device), covariance_matrix=cov) + cov += torch.eye(block_size, device=device) * 1e-8 + cov = cov.float() # Upcast to fp32 for numerical stability — cholesky is unreliable in fp16/bf16. + + L = torch.linalg.cholesky(cov) block_number = batch_size * channel * num_frames * (height // ph) * (width // pw) + z = torch.randn(block_number, block_size, device=device, generator=generator) + noise = z @ L.T - noise = dist.sample((block_number,)) # [block number, block_size] noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw) noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width) + return noise @property @@ -918,7 +929,14 @@ def __call__( batch_size, channel, num_frames, pyramid_height, pyramid_width = latents.shape noise = self.sample_block_noise( - batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size, device + batch_size, + channel, + num_frames, + pyramid_height, + pyramid_width, + patch_size, + device, + generator, ) noise = noise.to(device=device, dtype=transformer_dtype) latents = alpha * latents + beta * noise # To fix the block artifact