From 1f60b1932f515781641c2706608ab8aa02a0c54c Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 8 Mar 2026 07:55:20 +0000 Subject: [PATCH 1/3] fix Helios Context Parallelism --- .../models/transformers/transformer_helios.py | 17 ++++++++++++----- .../pipelines/helios/pipeline_helios_pyramid.py | 2 ++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 9f3ef047d98d..6d81f8a13af7 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -556,14 +556,21 @@ class HeliosTransformer3DModel( _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["HeliosTransformerBlock"] _cp_plan = { - "blocks.0": { + # Input split at attn level and ffn level. + "blocks.*.attn1": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - }, - "blocks.*": { - "temb": ContextParallelInput(split_dim=1, expected_dims=4, split_output=False), "rotary_emb": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), }, - "blocks.39": ContextParallelOutput(gather_dim=1, expected_dims=3), + "blocks.*.attn2": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "blocks.*.ffn": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + # Output gather at attn level and ffn level. 
+ **{f"blocks.{i}.attn1": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)}, + **{f"blocks.{i}.attn2": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)}, + **{f"blocks.{i}.ffn": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)}, } @register_to_config diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 0e08b2c6e958..5b8176c687ed 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -19,6 +19,7 @@ import regex as re import torch import torch.nn.functional as F +from accelerate.utils import broadcast from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -921,6 +922,7 @@ def __call__( batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size, device ) noise = noise.to(device=device, dtype=transformer_dtype) + noise = broadcast(noise, from_process=0) latents = alpha * latents + beta * noise # To fix the block artifact if self.config.is_distilled: From fe61e62dd6f7ca715f936f1dfbd518c8ade3230b Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 8 Mar 2026 09:13:27 +0000 Subject: [PATCH 2/3] refactor --- .../helios/pipeline_helios_pyramid.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 5b8176c687ed..4748e5ba1d28 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -19,7 +19,6 @@ import regex as re import torch import torch.nn.functional as F -from accelerate.utils import broadcast from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ 
-450,7 +449,14 @@ def sample_block_noise( width, patch_size: tuple[int, ...] = (1, 2, 2), device: torch.device | None = None, + generator: torch.Generator | None = None, ): + # NOTE: A generator must be provided to ensure correct and reproducible results. + # Creating a default generator here is a fallback only — without a fixed seed, + # the output will be non-deterministic and may produce incorrect results in CP context. + if generator is None: + generator = torch.Generator(device=device) + gamma = self.scheduler.config.gamma _, ph, pw = patch_size block_size = ph * pw @@ -459,13 +465,17 @@ def sample_block_noise( torch.eye(block_size, device=device) * (1 + gamma) - torch.ones(block_size, block_size, device=device) * gamma ) - cov += torch.eye(block_size, device=device) * 1e-6 - dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=device), covariance_matrix=cov) + cov += torch.eye(block_size, device=device) * 1e-8 + cov = cov.float() # Upcast to fp32 for numerical stability — cholesky is unreliable in fp16/bf16. 
+ + L = torch.linalg.cholesky(cov) block_number = batch_size * channel * num_frames * (height // ph) * (width // pw) + z = torch.randn(block_number, block_size, device=device, generator=generator) + noise = z @ L.T - noise = dist.sample((block_number,)) # [block number, block_size] noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw) noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width) + return noise @property @@ -919,10 +929,9 @@ def __call__( batch_size, channel, num_frames, pyramid_height, pyramid_width = latents.shape noise = self.sample_block_noise( - batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size, device + batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size, device, generator ) noise = noise.to(device=device, dtype=transformer_dtype) - noise = broadcast(noise, from_process=0) latents = alpha * latents + beta * noise # To fix the block artifact if self.config.is_distilled: From 645930a7b47f10bc3d19059028ea26408e1b594b Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 8 Mar 2026 09:14:39 +0000 Subject: [PATCH 3/3] make style and quality --- .../pipelines/helios/pipeline_helios_pyramid.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 4748e5ba1d28..d8f317a9a6f1 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -929,7 +929,14 @@ def __call__( batch_size, channel, num_frames, pyramid_height, pyramid_width = latents.shape noise = self.sample_block_noise( - batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size, device, generator + batch_size, + channel, + num_frames, + pyramid_height, + pyramid_width, + patch_size, + device, + generator, ) noise = noise.to(device=device, 
dtype=transformer_dtype) latents = alpha * latents + beta * noise # To fix the block artifact