diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 098660ec3f39..ea06f35a0343 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -194,6 +194,8 @@
title: Model accelerators and hardware
- isExpanded: false
sections:
+ - local: using-diffusers/helios
+ title: Helios
- local: using-diffusers/consisid
title: ConsisID
- local: using-diffusers/sdxl
@@ -350,6 +352,8 @@
title: FluxTransformer2DModel
- local: api/models/glm_image_transformer2d
title: GlmImageTransformer2DModel
+ - local: api/models/helios_transformer3d
+ title: HeliosTransformer3DModel
- local: api/models/hidream_image_transformer
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
@@ -625,7 +629,6 @@
title: Image-to-image
- local: api/pipelines/stable_diffusion/inpaint
title: Inpainting
-
- local: api/pipelines/stable_diffusion/latent_upscale
title: Latent upscaler
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
@@ -674,6 +677,8 @@
title: ConsisID
- local: api/pipelines/framepack
title: Framepack
+ - local: api/pipelines/helios
+ title: Helios
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/hunyuan_video15
@@ -745,6 +750,10 @@
title: FlowMatchEulerDiscreteScheduler
- local: api/schedulers/flow_match_heun_discrete
title: FlowMatchHeunDiscreteScheduler
+ - local: api/schedulers/helios_dmd
+ title: HeliosDMDScheduler
+ - local: api/schedulers/helios
+ title: HeliosScheduler
- local: api/schedulers/heun
title: HeunDiscreteScheduler
- local: api/schedulers/ipndm
diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md
index bbae6a9020af..db1ea884558f 100644
--- a/docs/source/en/api/loaders/lora.md
+++ b/docs/source/en/api/loaders/lora.md
@@ -23,6 +23,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
+- [`HeliosLoraLoaderMixin`] provides similar functions for [Helios](https://huggingface.co/docs/diffusers/main/en/api/pipelines/helios).
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
@@ -86,6 +87,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin
+## HeliosLoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.HeliosLoraLoaderMixin
+
## HunyuanVideoLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin
diff --git a/docs/source/en/api/models/helios_transformer3d.md b/docs/source/en/api/models/helios_transformer3d.md
new file mode 100644
index 000000000000..5aa2826c32ec
--- /dev/null
+++ b/docs/source/en/api/models/helios_transformer3d.md
@@ -0,0 +1,35 @@
+
+
+# HeliosTransformer3DModel
+
+A 14B real-time autoregressive Diffusion Transformer model (supporting T2V, I2V, and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios), introduced in [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) by Peking University, ByteDance, et al.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+from diffusers import HeliosTransformer3DModel
+
+# Best Quality
+transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="transformer", torch_dtype=torch.bfloat16)
+# Intermediate Weight
+transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="transformer", torch_dtype=torch.bfloat16)
+# Best Efficiency
+transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## HeliosTransformer3DModel
+
+[[autodoc]] HeliosTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md
new file mode 100644
index 000000000000..81559b24c071
--- /dev/null
+++ b/docs/source/en/api/pipelines/helios.md
@@ -0,0 +1,465 @@
+
+
+
+
+# Helios
+
+[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) from Peking University, ByteDance, et al., by Shenghai Yuan, Yuanyang Yin, Zongjian Li, Xinwei Huang, Xiao Yang, and Li Yuan.
+
+*We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality. We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios-Page).*
+
+The following Helios models are supported in Diffusers:
+
+- [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): best quality; uses v-prediction, standard CFG, and the custom `HeliosScheduler`.
+- [Helios-Mid](https://huggingface.co/BestWishYsh/Helios-Mid): intermediate weights; uses v-prediction, CFG-Zero*, and the custom `HeliosScheduler`.
+- [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): best efficiency; uses x0-prediction and the custom `HeliosDMDScheduler`.
+
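As a rough illustration of how the guidance modes above differ, here is a minimal sketch of standard CFG versus CFG-Zero* (following the general CFG-Zero* formulation; this is not the Helios implementation, and the helper names are invented):

```python
import torch

def cfg(noise_cond, noise_uncond, guidance_scale):
    # standard classifier-free guidance
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

def cfg_zero_star(noise_cond, noise_uncond, guidance_scale, step, zero_steps=1):
    # CFG-Zero*: return zeros for the first `zero_steps` steps ("zero-init"),
    # then rescale the unconditional branch with an optimized per-sample scale
    if step < zero_steps:
        return torch.zeros_like(noise_cond)
    flat_c = noise_cond.flatten(1)
    flat_u = noise_uncond.flatten(1)
    # s* = <cond, uncond> / ||uncond||^2, computed per sample
    s = (flat_c * flat_u).sum(1, keepdim=True) / (flat_u.pow(2).sum(1, keepdim=True) + 1e-8)
    s = s.view(-1, *([1] * (noise_cond.dim() - 1)))
    return s * noise_uncond + guidance_scale * (noise_cond - s * noise_uncond)
```

The `use_zero_init` and `zero_steps` arguments accepted by the pyramid pipeline correspond to the zero-init part of this scheme.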
+> [!TIP]
+> Click on the Helios models in the right sidebar for more examples of video generation.
+
+### Optimizing Memory and Inference Speed
+
+The example below demonstrates how to generate a video from text optimized for memory or inference speed.
+
+
+
+
+Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
+
+The Helios model below requires ~19GB of VRAM.
+
+```py
+import torch
+from diffusers import AutoModel, HeliosPipeline
+from diffusers.hooks.group_offloading import apply_group_offloading
+from diffusers.utils import export_to_video
+
+vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
+
+# group-offloading
+pipeline = HeliosPipeline.from_pretrained(
+ "BestWishYsh/Helios-Base",
+ vae=vae,
+ torch_dtype=torch.bfloat16
+)
+pipeline.enable_group_offload(
+ onload_device=torch.device("cuda"),
+ offload_device=torch.device("cpu"),
+ offload_type="block_level",
+ num_blocks_per_group=1,
+ use_stream=True,
+ record_stream=True,
+)
+
+prompt = """
+A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
+and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
+a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
+allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
+of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
+the vivid colors of its surroundings. A close-up shot with dynamic movement.
+"""
+negative_prompt = """
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=99,
+ num_inference_steps=50,
+ guidance_scale=5.0,
+ generator=torch.Generator("cuda").manual_seed(42),
+).frames[0]
+export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
+```
+
+
+
+
+[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Attention Backends](../../optimization/attention_backends) such as FlashAttention and SageAttention can significantly increase speed by optimizing the computation of the attention mechanism. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
+
+```py
+import torch
+from diffusers import AutoModel, HeliosPipeline
+from diffusers.utils import export_to_video
+
+vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
+
+pipeline = HeliosPipeline.from_pretrained(
+ "BestWishYsh/Helios-Base",
+ vae=vae,
+ torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+
+# attention backend
+# pipeline.transformer.set_attention_backend("flash")
+pipeline.transformer.set_attention_backend("_flash_3_hub") # For Hopper GPUs
+
+# torch.compile
+torch.backends.cudnn.benchmark = True
+pipeline.text_encoder.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+pipeline.vae.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+
+prompt = """
+A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
+and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
+a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
+allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
+of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
+the vivid colors of its surroundings. A close-up shot with dynamic movement.
+"""
+negative_prompt = """
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=99,
+ num_inference_steps=50,
+ guidance_scale=5.0,
+ generator=torch.Generator("cuda").manual_seed(42),
+).frames[0]
+export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
+```
+
+
+
+
+
+### Generation with Helios-Base
+
+The example below demonstrates how to use Helios-Base to generate videos from text, image, or video inputs.
+
+
+
+
+```python
+import torch
+from diffusers import AutoModel, HeliosPipeline
+from diffusers.utils import export_to_video, load_video, load_image
+
+vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
+
+pipeline = HeliosPipeline.from_pretrained(
+ "BestWishYsh/Helios-Base",
+ vae=vae,
+ torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+
+negative_prompt = """
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
+"""
+
+# For Text-to-Video
+prompt = """
+A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
+and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
+a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
+allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
+of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
+the vivid colors of its surroundings. A close-up shot with dynamic movement.
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=99,
+ num_inference_steps=50,
+ guidance_scale=5.0,
+ generator=torch.Generator("cuda").manual_seed(42),
+).frames[0]
+export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
+
+# For Image-to-Video
+prompt = """
+A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
+illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest,
+casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
+apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and
+relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
+respect for nature’s might.
+"""
+image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=load_image(image_path).resize((640, 384)),
+ num_frames=99,
+ num_inference_steps=50,
+ guidance_scale=5.0,
+ generator=torch.Generator("cuda").manual_seed(42),
+).frames[0]
+export_to_video(output, "helios_base_i2v_output.mp4", fps=24)
+
+# For Video-to-Video
+prompt = """
+A bright yellow Lamborghini Huracán Tecnica speeds along a curving mountain road, surrounded by lush green trees
+under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
+emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
+the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
+A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
+"""
+video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ video=load_video(video_path),
+ num_frames=99,
+ num_inference_steps=50,
+ guidance_scale=5.0,
+ generator=torch.Generator("cuda").manual_seed(42),
+).frames[0]
+export_to_video(output, "helios_base_v2v_output.mp4", fps=24)
+```
+
+
+
+
+
+### Generation with Helios-Mid
+
+The example below demonstrates how to use Helios-Mid to generate videos from text, image, or video inputs.
+
+
+
+
+```python
+import torch
+from diffusers import AutoModel, HeliosPyramidPipeline
+from diffusers.utils import export_to_video, load_video, load_image
+
+vae = AutoModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="vae", torch_dtype=torch.float32)
+
+pipeline = HeliosPyramidPipeline.from_pretrained(
+ "BestWishYsh/Helios-Mid",
+ vae=vae,
+ torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+
+negative_prompt = """
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
+"""
+
+# For Text-to-Video
+prompt = """
+A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
+and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
+a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
+allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
+of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
+the vivid colors of its surroundings. A close-up shot with dynamic movement.
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=99,
+ pyramid_num_inference_steps_list=[20, 20, 20],
+ guidance_scale=5.0,
+ use_zero_init=True,
+ zero_steps=1,
+ generator=torch.Generator("cuda").manual_seed(42),
+).frames[0]
+export_to_video(output, "helios_pyramid_t2v_output.mp4", fps=24)
+
+# For Image-to-Video
+prompt = """
+A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
+illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest,
+casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
+apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and
+relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
+respect for nature’s might.
+"""
+image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=load_image(image_path).resize((640, 384)),
+ num_frames=99,
+ pyramid_num_inference_steps_list=[20, 20, 20],
+ guidance_scale=5.0,
+ use_zero_init=True,
+ zero_steps=1,
+ generator=torch.Generator("cuda").manual_seed(42),
+).frames[0]
+export_to_video(output, "helios_pyramid_i2v_output.mp4", fps=24)
+
+# For Video-to-Video
+prompt = """
+A bright yellow Lamborghini Huracán Tecnica speeds along a curving mountain road, surrounded by lush green trees
+under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
+emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
+the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
+A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
+"""
+video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ video=load_video(video_path),
+ num_frames=99,
+ pyramid_num_inference_steps_list=[20, 20, 20],
+ guidance_scale=5.0,
+ use_zero_init=True,
+ zero_steps=1,
+ generator=torch.Generator("cuda").manual_seed(42),
+).frames[0]
+export_to_video(output, "helios_pyramid_v2v_output.mp4", fps=24)
+```
+
+
+
+
+
+### Generation with Helios-Distilled
+
+The example below demonstrates how to use Helios-Distilled to generate videos from text, image, or video inputs.
+
+
+
+
+```python
+import torch
+from diffusers import AutoModel, HeliosPyramidPipeline
+from diffusers.utils import export_to_video, load_video, load_image
+
+vae = AutoModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="vae", torch_dtype=torch.float32)
+
+pipeline = HeliosPyramidPipeline.from_pretrained(
+ "BestWishYsh/Helios-Distilled",
+ vae=vae,
+ torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+
+negative_prompt = """
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
+"""
+
+# For Text-to-Video
+prompt = """
+A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
+and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
+a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
+allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
+of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
+the vivid colors of its surroundings. A close-up shot with dynamic movement.
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=240,
+ pyramid_num_inference_steps_list=[2, 2, 2],
+ guidance_scale=1.0,
+ is_amplify_first_chunk=True,
+ generator=torch.Generator("cuda").manual_seed(42),
+).frames[0]
+export_to_video(output, "helios_distilled_t2v_output.mp4", fps=24)
+
+# For Image-to-Video
+prompt = """
+A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
+illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest,
+casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
+apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and
+relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
+respect for nature’s might.
+"""
+image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=load_image(image_path).resize((640, 384)),
+ num_frames=240,
+ pyramid_num_inference_steps_list=[2, 2, 2],
+ guidance_scale=1.0,
+ is_amplify_first_chunk=True,
+ generator=torch.Generator("cuda").manual_seed(42),
+).frames[0]
+export_to_video(output, "helios_distilled_i2v_output.mp4", fps=24)
+
+# For Video-to-Video
+prompt = """
+A bright yellow Lamborghini Huracán Tecnica speeds along a curving mountain road, surrounded by lush green trees
+under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
+emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
+the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
+A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
+"""
+video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ video=load_video(video_path),
+ num_frames=240,
+ pyramid_num_inference_steps_list=[2, 2, 2],
+ guidance_scale=1.0,
+ is_amplify_first_chunk=True,
+ generator=torch.Generator("cuda").manual_seed(42),
+).frames[0]
+export_to_video(output, "helios_distilled_v2v_output.mp4", fps=24)
+```
+
+
+
+
+
+## HeliosPipeline
+
+[[autodoc]] HeliosPipeline
+
+ - all
+ - __call__
+
+## HeliosPyramidPipeline
+
+[[autodoc]] HeliosPyramidPipeline
+
+ - all
+ - __call__
+
+## HeliosPipelineOutput
+
+[[autodoc]] pipelines.helios.pipeline_output.HeliosPipelineOutput
diff --git a/docs/source/en/api/schedulers/helios.md b/docs/source/en/api/schedulers/helios.md
new file mode 100644
index 000000000000..14c2be60bc89
--- /dev/null
+++ b/docs/source/en/api/schedulers/helios.md
@@ -0,0 +1,20 @@
+
+
+# HeliosScheduler
+
+`HeliosScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers).
+
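As a rough, hypothetical illustration of the pyramidal idea, the sampling trajectory can be split into stages, each covering a sub-interval of the flow-matching time range (`pyramid_timesteps` below is an invented helper for illustration, not the diffusers API; the actual scheduler also manages per-stage resolutions and renoising):

```python
def pyramid_timesteps(num_steps_per_stage, num_stages=3):
    # Split [0, 1] into `num_stages` equal sub-intervals and place
    # linearly spaced timesteps in each, from t=1 (noise) toward t=0 (clean).
    schedules = []
    for stage in range(num_stages):
        t_start = 1.0 - stage / num_stages
        t_end = 1.0 - (stage + 1) / num_stages
        n = num_steps_per_stage[stage]
        schedules.append([t_start + (t_end - t_start) * i / n for i in range(n)])
    return schedules
```

A `pyramid_num_inference_steps_list` such as `[20, 20, 20]` would then allocate 20 steps to each of three stages.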
+## HeliosScheduler
+[[autodoc]] HeliosScheduler
+
diff --git a/docs/source/en/api/schedulers/helios_dmd.md b/docs/source/en/api/schedulers/helios_dmd.md
new file mode 100644
index 000000000000..4f075e8a7dfc
--- /dev/null
+++ b/docs/source/en/api/schedulers/helios_dmd.md
@@ -0,0 +1,20 @@
+
+
+# HeliosDMDScheduler
+
+`HeliosDMDScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers).
+
+## HeliosDMDScheduler
+[[autodoc]] HeliosDMDScheduler
+
diff --git a/docs/source/en/using-diffusers/consisid.md b/docs/source/en/using-diffusers/consisid.md
index b6b04ddaf57e..96ece5b20c3a 100644
--- a/docs/source/en/using-diffusers/consisid.md
+++ b/docs/source/en/using-diffusers/consisid.md
@@ -60,7 +60,7 @@ export_to_video(video.frames[0], "output.mp4", fps=8)
Face Image
Video
- Description Description
diff --git a/docs/source/en/using-diffusers/helios.md b/docs/source/en/using-diffusers/helios.md
new file mode 100644
index 000000000000..8106f1c568f8
--- /dev/null
+++ b/docs/source/en/using-diffusers/helios.md
@@ -0,0 +1,133 @@
+
+# Helios
+
+[Helios](https://github.com/PKU-YuanGroup/Helios) is the first 14B video generation model that runs at 19.5 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching the quality of a strong baseline, natively integrating T2V, I2V, and V2V tasks within a unified architecture. The main features of Helios are:
+
+- Without commonly used anti-drifting strategies (e.g., self-forcing, error banks, keyframe sampling, or inverted sampling), Helios generates minute-scale videos with high quality and strong coherence.
+- Without standard acceleration techniques (e.g., KV-cache, causal masking, sparse/linear attention, TinyVAE, progressive noise schedules, hidden-state caching, or quantization), Helios achieves 19.5 FPS end-to-end inference for a 14B video generation model on a single H100 GPU.
+- Infrastructure-level optimizations improve both training and inference throughput while reducing memory consumption, enabling a 14B video generation model to be trained without parallelism or sharding frameworks, with batch sizes comparable to image models.
+
+This guide walks you through using Helios for a variety of use cases.
+
+## Load Model Checkpoints
+
+Model weights may be stored in separate subfolders on the Hub or locally, in which case you should use the [`~DiffusionPipeline.from_pretrained`] method.
+
+```python
+import torch
+from diffusers import HeliosPipeline, HeliosPyramidPipeline
+from huggingface_hub import snapshot_download
+
+# For Best Quality
+snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base")
+pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Intermediate Weight
+snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid")
+pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# For Best Efficiency
+snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled")
+pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+```
+
+## Text-to-Video Showcases
+
+
+
+ Prompt
+ Generated Video
+
+
+ A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression.
+
+
+
+
+
+
+
+
+ A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle.
+
+
+
+
+
+
+
+
+
+## Image-to-Video Showcases
+
+
+
+ Image
+ Prompt
+ Generated Video
+
+
+
+ A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees.
+
+
+
+
+
+
+
+
+
+ A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective.
+
+
+
+
+
+
+
+
+
+## Interactive-Video Showcases
+
+| Prompt | Generated Video |
+| --- | --- |
+| The prompt can be found here |  |
+| The prompt can be found here |  |
+
+## Resources
+
+Learn more about Helios with the following resources.
+- Watch [video1](https://www.youtube.com/watch?v=vd_AgHtOUFQ) and [video2](https://www.youtube.com/watch?v=1GeIU2Dn7UY) for a demonstration of Helios's key features.
+- Read the research paper, [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/), for more details.
diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml
index 337d010fc74d..ab9eaf6ec7fb 100644
--- a/docs/source/zh/_toctree.yml
+++ b/docs/source/zh/_toctree.yml
@@ -132,6 +132,8 @@
sections:
- local: using-diffusers/consisid
title: ConsisID
+ - local: using-diffusers/helios
+ title: Helios
- title: Resources
isExpanded: false
diff --git a/docs/source/zh/community_projects.md b/docs/source/zh/community_projects.md
index 0440142452f1..ffa45f1e9bb0 100644
--- a/docs/source/zh/community_projects.md
+++ b/docs/source/zh/community_projects.md
@@ -26,6 +26,14 @@ http://www.apache.org/licenses/LICENSE-2.0
项目名称
描述
+
+ helios
+ Helios:比 1.3B 模型开销更低、速度更快且效果更强的 14B 实时长视频生成模型
+
+
+ consisid
+ ConsisID:零样本身份保持的文本到视频生成模型
+
dream-textures
Stable Diffusion内置到Blender
diff --git a/docs/source/zh/using-diffusers/helios.md b/docs/source/zh/using-diffusers/helios.md
new file mode 100644
index 000000000000..5c4faed2ca2a
--- /dev/null
+++ b/docs/source/zh/using-diffusers/helios.md
@@ -0,0 +1,134 @@
+
+# Helios
+
+[Helios](https://github.com/PKU-YuanGroup/Helios) 是首个能够在单张 NVIDIA H100 GPU 上以 19.5 FPS 运行的 14B 视频生成模型。它在支持分钟级视频生成的同时,拥有媲美强大基线模型的生成质量,并在统一架构下原生集成了文生视频(T2V)、图生视频(I2V)和视频生视频(V2V)任务。Helios 的主要特性包括:
+
+- 无需常用的防漂移策略(例如:自强制/self-forcing、误差库/error-banks、关键帧采样或逆采样),我们的模型即可生成高质量且高度连贯的分钟级视频。
+- 无需标准的加速技术(例如:KV 缓存、因果掩码、稀疏/线性注意力机制、TinyVAE、渐进式噪声调度、隐藏状态缓存或量化),作为一款 14B 规模的视频生成模型,我们在单张 H100 GPU 上的端到端推理速度便达到了 19.5 FPS。
+- 引入了多项优化方案,在降低显存消耗的同时,显著提升了训练与推理的吞吐量。这些改进使得我们无需借助并行或分片(sharding)等基础设施,即可使用与图像模型相当的批大小(batch sizes)来训练 14B 的视频生成模型。
+
+本指南将引导您完成 Helios 在不同场景下的使用。
+
+## Load Model Checkpoints
+
+模型权重可以存储在 Hub 上或本地的单独子文件夹中。在这种情况下,应使用 [`~DiffusionPipeline.from_pretrained`] 方法加载。
+
+```python
+import torch
+from diffusers import HeliosPipeline, HeliosPyramidPipeline
+from huggingface_hub import snapshot_download
+
+# For Best Quality
+snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base")
+pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Intermediate Weight
+snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid")
+pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# For Best Efficiency
+snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled")
+pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+```
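+
+加载管线后,即可按 diffusers 视频管线的常见方式进行推理。下面是一个假设性的文生视频示例草图(`export_to_video` 来自 `diffusers.utils`;调用参数仅为示意,具体以 Helios 官方示例为准):
+
+```python
+import torch
+from diffusers import HeliosPipeline
+from diffusers.utils import export_to_video
+
+# 加载最佳质量的权重(需要支持 bfloat16 的 GPU)
+pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A fluffy orange cat gazing intently at the camera in a cozy indoor setting."
+# 假设性调用:输出对象的 frames 字段包含生成的视频帧
+video = pipe(prompt=prompt).frames[0]
+export_to_video(video, "output.mp4", fps=24)
+```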
+
+## Text-to-Video Showcases
+
+| Prompt | Generated Video |
+| --- | --- |
+| A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression. |  |
+| A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle. |  |
+
+## Image-to-Video Showcases
+
+| Image | Prompt | Generated Video |
+| --- | --- | --- |
+|  | A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees. |  |
+|  | A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective. |  |
+
+## Interactive-Video Showcases
+
+| Prompt | Generated Video |
+| --- | --- |
+| The prompt can be found here |  |
+| The prompt can be found here |  |
+
+## Resources
+
+通过以下资源了解有关 Helios 的更多信息:
+
+- [视频1](https://www.youtube.com/watch?v=vd_AgHtOUFQ)和[视频2](https://www.youtube.com/watch?v=1GeIU2Dn7UY)演示了 Helios 的主要功能;
+- 有关更多详细信息,请参阅研究论文 [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/)。
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 1fc0914fe09e..1458164191df 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -227,6 +227,7 @@
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"GlmImageTransformer2DModel",
+ "HeliosTransformer3DModel",
"HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
@@ -359,6 +360,8 @@
"FlowMatchEulerDiscreteScheduler",
"FlowMatchHeunDiscreteScheduler",
"FlowMatchLCMScheduler",
+ "HeliosDMDScheduler",
+ "HeliosScheduler",
"HeunDiscreteScheduler",
"IPNDMScheduler",
"KarrasVeScheduler",
@@ -515,6 +518,8 @@
"FluxPipeline",
"FluxPriorReduxPipeline",
"GlmImagePipeline",
+ "HeliosPipeline",
+ "HeliosPyramidPipeline",
"HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
@@ -994,6 +999,7 @@
FluxMultiControlNetModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
+ HeliosTransformer3DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
@@ -1122,6 +1128,8 @@
FlowMatchEulerDiscreteScheduler,
FlowMatchHeunDiscreteScheduler,
FlowMatchLCMScheduler,
+ HeliosDMDScheduler,
+ HeliosScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
@@ -1257,6 +1265,8 @@
FluxPipeline,
FluxPriorReduxPipeline,
GlmImagePipeline,
+ HeliosPipeline,
+ HeliosPyramidPipeline,
HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index bdd4dbbcd4b5..ed0d2a07336f 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -78,6 +78,7 @@ def text_encoder_attn_modules(text_encoder):
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
+ "HeliosLoraLoaderMixin",
"KandinskyLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
@@ -118,6 +119,7 @@ def text_encoder_attn_modules(text_encoder):
CogView4LoraLoaderMixin,
Flux2LoraLoaderMixin,
FluxLoraLoaderMixin,
+ HeliosLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
KandinskyLoraLoaderMixin,
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 3423a88d3368..5d10f596f2e6 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -3440,6 +3440,207 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
super().unfuse_lora(components=components, **kwargs)
+class HeliosLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`HeliosTransformer3DModel`]. Specific to [`HeliosPipeline`] and [`HeliosPyramidPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+ if any(k.startswith("diffusion_model.") for k in state_dict):
+ state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
+ elif any(k.startswith("lora_unet_") for k in state_dict):
+ state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
+ adapter_name: str | None = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HeliosTransformer3DModel
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: str | os.PathLike,
+ transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: dict | None = None,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
+ """
+ lora_layers = {}
+ lora_metadata = {}
+
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
+ save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: list[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: list[str] | None = None,
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 80fb6a72869a..a96542c2a50c 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -51,6 +51,7 @@
"FluxTransformer2DModel": lambda model_cls, weights: weights,
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
"ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
+ "HeliosTransformer3DModel": lambda model_cls, weights: weights,
"MochiTransformer3DModel": lambda model_cls, weights: weights,
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 96953afa4f4a..8b8d9c52659e 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -100,6 +100,7 @@
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
+ _import_structure["transformers.transformer_helios"] = ["HeliosTransformer3DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
@@ -212,6 +213,7 @@
Flux2Transformer2DModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
+ HeliosTransformer3DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index d9d1b27a1e40..45157ee91808 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -28,6 +28,7 @@
from .transformer_flux import FluxTransformer2DModel
from .transformer_flux2 import Flux2Transformer2DModel
from .transformer_glm_image import GlmImageTransformer2DModel
+ from .transformer_helios import HeliosTransformer3DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py
new file mode 100644
index 000000000000..9f3ef047d98d
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_helios.py
@@ -0,0 +1,814 @@
+# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from functools import lru_cache
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import apply_lora_scale, logging
+from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def pad_for_3d_conv(x, kernel_size):
+ b, c, t, h, w = x.shape
+ pt, ph, pw = kernel_size
+ pad_t = (pt - (t % pt)) % pt
+ pad_h = (ph - (h % ph)) % ph
+ pad_w = (pw - (w % pw)) % pw
+ return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
+
+
+def center_down_sample_3d(x, kernel_size):
+ return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
+
+
+def apply_rotary_emb_transposed(
+ hidden_states: torch.Tensor,
+ freqs_cis: torch.Tensor,
+):
+ x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]
+ out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]
+ return out.type_as(hidden_states)
+
+
+def _get_qkv_projections(attn: "HeliosAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+ # encoder_hidden_states is only passed for cross-attention
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if not attn.is_cross_attention:
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ return query, key, value
+
+
+class HeliosOutputNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
+ super().__init__()
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+ self.norm = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, original_context_length: int):
+ temb = temb[:, -original_context_length:, :]
+ shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift, scale = shift.squeeze(2).to(hidden_states.device), scale.squeeze(2).to(hidden_states.device)
+ hidden_states = hidden_states[:, -original_context_length:, :]
+ hidden_states = (self.norm(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ return hidden_states
+
+
+class HeliosAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "HeliosAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+ )
+
+ def __call__(
+ self,
+ attn: "HeliosAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+ original_context_length: int | None = None,
+ ) -> torch.Tensor:
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ if rotary_emb is not None:
+ query = apply_rotary_emb_transposed(query, rotary_emb)
+ key = apply_rotary_emb_transposed(key, rotary_emb)
+
+ if not attn.is_cross_attention and attn.is_amplify_history:
+ history_seq_len = hidden_states.shape[1] - original_context_length
+
+ if history_seq_len > 0:
+ scale_key = 1.0 + torch.sigmoid(attn.history_key_scale) * (attn.max_scale - 1.0)
+ if attn.history_scale_mode == "per_head":
+ scale_key = scale_key.view(1, 1, -1, 1)
+ key = torch.cat([key[:, :history_seq_len] * scale_key, key[:, history_seq_len:]], dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ # Reference: https://github.com/huggingface/diffusers/pull/12909
+ parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class HeliosAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = HeliosAttnProcessor
+ _available_processors = [HeliosAttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ eps: float = 1e-5,
+ dropout: float = 0.0,
+ added_kv_proj_dim: int | None = None,
+ cross_attention_dim_head: int | None = None,
+ processor=None,
+ is_cross_attention=None,
+ is_amplify_history=False,
+ history_scale_mode="per_head", # [scalar, per_head]
+ ):
+ super().__init__()
+
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.cross_attention_dim_head = cross_attention_dim_head
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_out = torch.nn.ModuleList(
+ [
+ torch.nn.Linear(self.inner_dim, dim, bias=True),
+ torch.nn.Dropout(dropout),
+ ]
+ )
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+ self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+ self.add_k_proj = self.add_v_proj = None
+ if added_kv_proj_dim is not None:
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+ if is_cross_attention is not None:
+ self.is_cross_attention = is_cross_attention
+ else:
+ self.is_cross_attention = cross_attention_dim_head is not None
+
+ self.set_processor(processor)
+
+ self.is_amplify_history = is_amplify_history
+ if is_amplify_history:
+ if history_scale_mode == "scalar":
+ self.history_key_scale = nn.Parameter(torch.ones(1))
+ elif history_scale_mode == "per_head":
+ self.history_key_scale = nn.Parameter(torch.ones(heads))
+ else:
+ raise ValueError(f"Unknown history_scale_mode: {history_scale_mode}")
+ self.history_scale_mode = history_scale_mode
+ self.max_scale = 10.0
+
+ def fuse_projections(self):
+ if getattr(self, "fused_projections", False):
+ return
+
+ if not self.is_cross_attention:
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+ self.to_qkv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ if self.added_kv_proj_dim is not None:
+ concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+ concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_added_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ if not getattr(self, "fused_projections", False):
+ return
+
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+ if hasattr(self, "to_added_kv"):
+ delattr(self, "to_added_kv")
+
+ self.fused_projections = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+ original_context_length: int | None = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states,
+ attention_mask,
+ rotary_emb,
+ original_context_length,
+ **kwargs,
+ )
+
+
+class HeliosTimeTextEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ time_freq_dim: int,
+ time_proj_dim: int,
+ text_embed_dim: int,
+ ):
+ super().__init__()
+
+ self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.act_fn = nn.SiLU()
+ self.time_proj = nn.Linear(dim, time_proj_dim)
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor | None = None,
+ is_return_encoder_hidden_states: bool = True,
+ ):
+ timestep = self.timesteps_proj(timestep)
+
+ time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ timestep_proj = self.time_proj(self.act_fn(temb))
+
+ if encoder_hidden_states is not None and is_return_encoder_hidden_states:
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+
+ return temb, timestep_proj, encoder_hidden_states
+
+
+class HeliosRotaryPosEmbed(nn.Module):
+ def __init__(self, rope_dim, theta):
+ super().__init__()
+ self.DT, self.DY, self.DX = rope_dim
+ self.theta = theta
+ self.register_buffer("freqs_base_t", self._get_freqs_base(self.DT), persistent=False)
+ self.register_buffer("freqs_base_y", self._get_freqs_base(self.DY), persistent=False)
+ self.register_buffer("freqs_base_x", self._get_freqs_base(self.DX), persistent=False)
+
+ def _get_freqs_base(self, dim):
+ return 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim))
+
+ @torch.no_grad()
+ def get_frequency_batched(self, freqs_base, pos):
+ freqs = torch.einsum("d,bthw->dbthw", freqs_base, pos)
+ freqs = freqs.repeat_interleave(2, dim=0)
+ return freqs.cos(), freqs.sin()
+
+ @torch.no_grad()
+ @lru_cache(maxsize=32)
+ def _get_spatial_meshgrid(self, height, width, device_str):
+ device = torch.device(device_str)
+ grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)
+ grid_x_coords = torch.arange(width, device=device, dtype=torch.float32)
+ grid_y, grid_x = torch.meshgrid(grid_y_coords, grid_x_coords, indexing="ij")
+ return grid_y, grid_x
+
+ @torch.no_grad()
+ def forward(self, frame_indices, height, width, device):
+ batch_size = frame_indices.shape[0]
+ num_frames = frame_indices.shape[1]
+
+ frame_indices = frame_indices.to(device=device, dtype=torch.float32)
+ grid_y, grid_x = self._get_spatial_meshgrid(height, width, str(device))
+
+ grid_t = frame_indices[:, :, None, None].expand(batch_size, num_frames, height, width)
+ grid_y_batch = grid_y[None, None, :, :].expand(batch_size, num_frames, -1, -1)
+ grid_x_batch = grid_x[None, None, :, :].expand(batch_size, num_frames, -1, -1)
+
+ freqs_cos_t, freqs_sin_t = self.get_frequency_batched(self.freqs_base_t, grid_t)
+ freqs_cos_y, freqs_sin_y = self.get_frequency_batched(self.freqs_base_y, grid_y_batch)
+ freqs_cos_x, freqs_sin_x = self.get_frequency_batched(self.freqs_base_x, grid_x_batch)
+
+ result = torch.cat([freqs_cos_t, freqs_cos_y, freqs_cos_x, freqs_sin_t, freqs_sin_y, freqs_sin_x], dim=0)
+
+ return result.permute(1, 0, 2, 3, 4)
+
+
+@maybe_allow_in_graph
+class HeliosTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ qk_norm: str = "rms_norm_across_heads",
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ added_kv_proj_dim: int | None = None,
+ guidance_cross_attn: bool = False,
+ is_amplify_history: bool = False,
+ history_scale_mode: str = "per_head", # [scalar, per_head]
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = HeliosAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ cross_attention_dim_head=None,
+ processor=HeliosAttnProcessor(),
+ is_amplify_history=is_amplify_history,
+ history_scale_mode=history_scale_mode,
+ )
+
+ # 2. Cross-attention
+ self.attn2 = HeliosAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ added_kv_proj_dim=added_kv_proj_dim,
+ cross_attention_dim_head=dim // num_heads,
+ processor=HeliosAttnProcessor(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 3. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ # 4. Guidance cross-attention
+ self.guidance_cross_attn = guidance_cross_attn
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+        original_context_length: int | None = None,
+ ) -> torch.Tensor:
+ if temb.ndim == 4:
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table.unsqueeze(0) + temb.float()
+ ).chunk(6, dim=2)
+ # batch_size, seq_len, 1, inner_dim
+ shift_msa = shift_msa.squeeze(2)
+ scale_msa = scale_msa.squeeze(2)
+ gate_msa = gate_msa.squeeze(2)
+ c_shift_msa = c_shift_msa.squeeze(2)
+ c_scale_msa = c_scale_msa.squeeze(2)
+ c_gate_msa = c_gate_msa.squeeze(2)
+ else:
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
+
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ attn_output = self.attn1(
+ norm_hidden_states,
+ None,
+ None,
+ rotary_emb,
+ original_context_length,
+ )
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+ # 2. Cross-attention
+ if self.guidance_cross_attn:
+ history_seq_len = hidden_states.shape[1] - original_context_length
+
+ history_hidden_states, hidden_states = torch.split(
+ hidden_states, [history_seq_len, original_context_length], dim=1
+ )
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states,
+ None,
+ None,
+ original_context_length,
+ )
+ hidden_states = hidden_states + attn_output
+ hidden_states = torch.cat([history_hidden_states, hidden_states], dim=1)
+ else:
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states,
+ None,
+ None,
+ original_context_length,
+ )
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+ return hidden_states
+
+
+class HeliosTransformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ r"""
+ A Transformer model for video-like data used in the Helios model.
+
+ Args:
+ patch_size (`tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+        num_attention_heads (`int`, defaults to `40`):
+            The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+        text_dim (`int`, defaults to `4096`):
+ Input dimension for text embeddings.
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `13824`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `40`):
+ The number of layers of transformer blocks to use.
+        cross_attn_norm (`bool`, defaults to `True`):
+            Enable cross-attention normalization.
+        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
+            The normalization applied to the query and key projections.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+        rope_dim (`tuple[int]`, defaults to `(44, 42, 42)`):
+            Rotary embedding dimensions for the temporal, height, and width axes.
+        rope_theta (`float`, defaults to `10000.0`):
+            Base frequency for the rotary position embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = [
+ "patch_embedding",
+ "patch_short",
+ "patch_mid",
+ "patch_long",
+ "condition_embedder",
+ "norm",
+ ]
+ _no_split_modules = ["HeliosTransformerBlock", "HeliosOutputNorm"]
+ _keep_in_fp32_modules = [
+ "time_embedder",
+ "scale_shift_table",
+ "norm1",
+ "norm2",
+ "norm3",
+ "history_key_scale",
+ ]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+ _repeated_blocks = ["HeliosTransformerBlock"]
+ _cp_plan = {
+ "blocks.0": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "blocks.*": {
+ "temb": ContextParallelInput(split_dim=1, expected_dims=4, split_output=False),
+ "rotary_emb": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "blocks.39": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: tuple[int, ...] = (1, 2, 2),
+ num_attention_heads: int = 40,
+ attention_head_dim: int = 128,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 13824,
+ num_layers: int = 40,
+ cross_attn_norm: bool = True,
+ qk_norm: str | None = "rms_norm_across_heads",
+ eps: float = 1e-6,
+ added_kv_proj_dim: int | None = None,
+ rope_dim: tuple[int, ...] = (44, 42, 42),
+ rope_theta: float = 10000.0,
+ guidance_cross_attn: bool = True,
+ zero_history_timestep: bool = True,
+ has_multi_term_memory_patch: bool = True,
+ is_amplify_history: bool = False,
+ history_scale_mode: str = "per_head", # [scalar, per_head]
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Patch & position embedding
+ self.rope = HeliosRotaryPosEmbed(rope_dim=rope_dim, theta=rope_theta)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Initial Multi Term Memory Patch
+ self.zero_history_timestep = zero_history_timestep
+ if has_multi_term_memory_patch:
+ self.patch_short = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+ self.patch_mid = nn.Conv3d(
+ in_channels,
+ inner_dim,
+ kernel_size=tuple(2 * p for p in patch_size),
+ stride=tuple(2 * p for p in patch_size),
+ )
+ self.patch_long = nn.Conv3d(
+ in_channels,
+ inner_dim,
+ kernel_size=tuple(4 * p for p in patch_size),
+ stride=tuple(4 * p for p in patch_size),
+ )
+
+ # 3. Condition embeddings
+ self.condition_embedder = HeliosTimeTextEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ )
+
+ # 4. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ HeliosTransformerBlock(
+ inner_dim,
+ ffn_dim,
+ num_attention_heads,
+ qk_norm,
+ cross_attn_norm,
+ eps,
+ added_kv_proj_dim,
+ guidance_cross_attn=guidance_cross_attn,
+ is_amplify_history=is_amplify_history,
+ history_scale_mode=history_scale_mode,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 5. Output norm & projection
+ self.norm_out = HeliosOutputNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+
+ self.gradient_checkpointing = False
+
+ @apply_lora_scale("attention_kwargs")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ # ------------ Stage 1 ------------
+ indices_hidden_states=None,
+ indices_latents_history_short=None,
+ indices_latents_history_mid=None,
+ indices_latents_history_long=None,
+ latents_history_short=None,
+ latents_history_mid=None,
+ latents_history_long=None,
+ return_dict: bool = True,
+ attention_kwargs: dict[str, Any] | None = None,
+ ) -> torch.Tensor | dict[str, torch.Tensor]:
+ # 1. Input
+ batch_size = hidden_states.shape[0]
+ p_t, p_h, p_w = self.config.patch_size
+
+ # 2. Process noisy latents
+ hidden_states = self.patch_embedding(hidden_states)
+ _, _, post_patch_num_frames, post_patch_height, post_patch_width = hidden_states.shape
+
+ if indices_hidden_states is None:
+ indices_hidden_states = torch.arange(0, post_patch_num_frames).unsqueeze(0).expand(batch_size, -1)
+
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+ rotary_emb = self.rope(
+ frame_indices=indices_hidden_states,
+ height=post_patch_height,
+ width=post_patch_width,
+ device=hidden_states.device,
+ )
+ rotary_emb = rotary_emb.flatten(2).transpose(1, 2)
+ original_context_length = hidden_states.shape[1]
+
+ # 3. Process short history latents
+ if latents_history_short is not None and indices_latents_history_short is not None:
+ latents_history_short = self.patch_short(latents_history_short)
+ _, _, _, H1, W1 = latents_history_short.shape
+ latents_history_short = latents_history_short.flatten(2).transpose(1, 2)
+
+ rotary_emb_history_short = self.rope(
+ frame_indices=indices_latents_history_short,
+ height=H1,
+ width=W1,
+ device=latents_history_short.device,
+ )
+ rotary_emb_history_short = rotary_emb_history_short.flatten(2).transpose(1, 2)
+
+ hidden_states = torch.cat([latents_history_short, hidden_states], dim=1)
+ rotary_emb = torch.cat([rotary_emb_history_short, rotary_emb], dim=1)
+
+ # 4. Process mid history latents
+ if latents_history_mid is not None and indices_latents_history_mid is not None:
+ latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4))
+ latents_history_mid = self.patch_mid(latents_history_mid)
+ latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2)
+
+ rotary_emb_history_mid = self.rope(
+ frame_indices=indices_latents_history_mid,
+ height=H1,
+ width=W1,
+ device=latents_history_mid.device,
+ )
+ rotary_emb_history_mid = pad_for_3d_conv(rotary_emb_history_mid, (2, 2, 2))
+ rotary_emb_history_mid = center_down_sample_3d(rotary_emb_history_mid, (2, 2, 2))
+ rotary_emb_history_mid = rotary_emb_history_mid.flatten(2).transpose(1, 2)
+
+ hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1)
+ rotary_emb = torch.cat([rotary_emb_history_mid, rotary_emb], dim=1)
+
+ # 5. Process long history latents
+ if latents_history_long is not None and indices_latents_history_long is not None:
+ latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8))
+ latents_history_long = self.patch_long(latents_history_long)
+ latents_history_long = latents_history_long.flatten(2).transpose(1, 2)
+
+ rotary_emb_history_long = self.rope(
+ frame_indices=indices_latents_history_long,
+ height=H1,
+ width=W1,
+ device=latents_history_long.device,
+ )
+ rotary_emb_history_long = pad_for_3d_conv(rotary_emb_history_long, (4, 4, 4))
+ rotary_emb_history_long = center_down_sample_3d(rotary_emb_history_long, (4, 4, 4))
+ rotary_emb_history_long = rotary_emb_history_long.flatten(2).transpose(1, 2)
+
+ hidden_states = torch.cat([latents_history_long, hidden_states], dim=1)
+ rotary_emb = torch.cat([rotary_emb_history_long, rotary_emb], dim=1)
+
+ history_context_length = hidden_states.shape[1] - original_context_length
+
+ if indices_hidden_states is not None and self.zero_history_timestep:
+            timestep_t0 = torch.zeros((1,), dtype=timestep.dtype, device=timestep.device)
+ temb_t0, timestep_proj_t0, _ = self.condition_embedder(
+ timestep_t0, encoder_hidden_states, is_return_encoder_hidden_states=False
+ )
+ temb_t0 = temb_t0.unsqueeze(1).expand(batch_size, history_context_length, -1)
+ timestep_proj_t0 = (
+ timestep_proj_t0.unflatten(-1, (6, -1))
+ .view(1, 6, 1, -1)
+ .expand(batch_size, -1, history_context_length, -1)
+ )
+
+ temb, timestep_proj, encoder_hidden_states = self.condition_embedder(timestep, encoder_hidden_states)
+ timestep_proj = timestep_proj.unflatten(-1, (6, -1))
+
+ if indices_hidden_states is not None and not self.zero_history_timestep:
+ main_repeat_size = hidden_states.shape[1]
+ else:
+ main_repeat_size = original_context_length
+ temb = temb.view(batch_size, 1, -1).expand(batch_size, main_repeat_size, -1)
+ timestep_proj = timestep_proj.view(batch_size, 6, 1, -1).expand(batch_size, 6, main_repeat_size, -1)
+
+ if indices_hidden_states is not None and self.zero_history_timestep:
+ temb = torch.cat([temb_t0, temb], dim=1)
+ timestep_proj = torch.cat([timestep_proj_t0, timestep_proj], dim=2)
+
+ if timestep_proj.ndim == 4:
+ timestep_proj = timestep_proj.permute(0, 2, 1, 3)
+
+ # 6. Transformer blocks
+ hidden_states = hidden_states.contiguous()
+ encoder_hidden_states = encoder_hidden_states.contiguous()
+ rotary_emb = rotary_emb.contiguous()
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.blocks:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ original_context_length,
+ )
+ else:
+ for block in self.blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ original_context_length,
+ )
+
+ # 7. Normalization
+ hidden_states = self.norm_out(hidden_states, temb, original_context_length)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 8. Unpatchify
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
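The unpatchify in step 8 inverts the `nn.Conv3d` patch embedding by reshaping each token back into its `(p_t, p_h, p_w)` patch. A minimal shape check of that reshape/permute/flatten sequence (the tensor sizes below are illustrative, not tied to the real model config):

```python
import torch

# Illustrative sizes: (p_t, p_h, p_w) = (1, 2, 2), 16 output channels,
# a 2-frame, 4x4 patched latent grid.
batch_size, post_patch_num_frames, post_patch_height, post_patch_width = 1, 2, 4, 4
p_t, p_h, p_w, out_channels = 1, 2, 2, 16

seq_len = post_patch_num_frames * post_patch_height * post_patch_width
# proj_out produces out_channels * p_t * p_h * p_w features per token
hidden_states = torch.randn(batch_size, seq_len, out_channels * p_t * p_h * p_w)

# Same sequence as step 8 above
hidden_states = hidden_states.reshape(
    batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

print(tuple(output.shape))  # (1, 16, 2, 8, 8) == (B, C, T*p_t, H*p_h, W*p_w)
```

Each token contributes one `p_t x p_h x p_w` block of the reconstructed latent video, so the output spatial/temporal sizes are the patched grid sizes scaled back up by the patch size.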
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 638598051d64..08cb28a6237a 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -237,6 +237,7 @@
"EasyAnimateInpaintPipeline",
"EasyAnimateControlPipeline",
]
+ _import_structure["helios"] = ["HeliosPipeline", "HeliosPyramidPipeline"]
_import_structure["hidream_image"] = ["HiDreamImagePipeline"]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["hunyuan_video"] = [
@@ -667,6 +668,7 @@
)
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
from .glm_image import GlmImagePipeline
+ from .helios import HeliosPipeline, HeliosPyramidPipeline
from .hidream_image import HiDreamImagePipeline
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
from .hunyuan_video import (
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index 7247ca0d161c..72151dc40a53 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -54,6 +54,7 @@
)
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
from .glm_image import GlmImagePipeline
+from .helios import HeliosPipeline, HeliosPyramidPipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
@@ -174,6 +175,8 @@
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("glm_image", GlmImagePipeline),
+ ("helios", HeliosPipeline),
+ ("helios-pyramid", HeliosPyramidPipeline),
("cogview4-control", CogView4ControlPipeline),
("qwenimage", QwenImagePipeline),
("qwenimage-controlnet", QwenImageControlNetPipeline),
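The auto-pipeline mapping above resolves the model type recorded in a checkpoint's config to a pipeline class by exact key lookup, which is why `helios` and `helios-pyramid` need separate entries. A toy sketch of that resolution (the helper name and string values are illustrative, not the diffusers internals, which map to pipeline classes rather than strings):

```python
from collections import OrderedDict

# Illustrative stand-in for the mapping above
AUTO_MAPPING = OrderedDict(
    [
        ("helios", "HeliosPipeline"),
        ("helios-pyramid", "HeliosPyramidPipeline"),
    ]
)


def resolve_pipeline(model_type: str) -> str:
    # Exact-match lookup: "helios-pyramid" never falls back to "helios"
    if model_type not in AUTO_MAPPING:
        raise ValueError(f"Unknown model type: {model_type!r}")
    return AUTO_MAPPING[model_type]


print(resolve_pipeline("helios-pyramid"))  # HeliosPyramidPipeline
```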
diff --git a/src/diffusers/pipelines/helios/__init__.py b/src/diffusers/pipelines/helios/__init__.py
new file mode 100644
index 000000000000..ae08f5997279
--- /dev/null
+++ b/src/diffusers/pipelines/helios/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_helios"] = ["HeliosPipeline"]
+ _import_structure["pipeline_helios_pyramid"] = ["HeliosPyramidPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_helios import HeliosPipeline
+ from .pipeline_helios_pyramid import HeliosPyramidPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
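This `__init__.py` follows the standard diffusers lazy-import layout: `_import_structure` maps submodule names to the symbols they define, and `_LazyModule` defers the actual imports until first attribute access. A minimal stand-in for that mechanism (the class below is a toy, not the `diffusers.utils._LazyModule` implementation; `json`/`dumps` merely stand in for a pipeline module and class):

```python
import importlib
import types


class ToyLazyModule(types.ModuleType):
    """Toy sketch: defer submodule imports until attribute access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Reverse map: symbol name -> submodule that defines it
        self._symbol_to_module = {
            symbol: module for module, symbols in import_structure.items() for symbol in symbols
        }

    def __getattr__(self, name):
        if name in self._symbol_to_module:
            submodule = importlib.import_module(self._symbol_to_module[name])
            value = getattr(submodule, name)
            setattr(self, name, value)  # cache so later lookups skip __getattr__
            return value
        raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")


# Usage sketch: "json" stands in for e.g. "pipeline_helios", "dumps" for "HeliosPipeline"
mod = ToyLazyModule("toy_helios", {"json": ["dumps"]})
print(mod.dumps({"a": 1}))  # the submodule is only imported here
```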
diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py
new file mode 100644
index 000000000000..87a8600badab
--- /dev/null
+++ b/src/diffusers/pipelines/helios/pipeline_helios.py
@@ -0,0 +1,916 @@
+# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Any, Callable
+
+import numpy as np
+import regex as re
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import HeliosLoraLoaderMixin
+from ...models import AutoencoderKLWan, HeliosTransformer3DModel
+from ...schedulers import HeliosScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import HeliosPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers.utils import export_to_video
+ >>> from diffusers import AutoencoderKLWan, HeliosPipeline
+
+ >>> # Available models: BestWishYsh/Helios-Base, BestWishYsh/Helios-Mid, BestWishYsh/Helios-Distilled
+ >>> model_id = "BestWishYsh/Helios-Base"
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = HeliosPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=384,
+ ... width=640,
+ ... num_frames=132,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=24)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+class HeliosPipeline(DiffusionPipeline, HeliosLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video / image-to-video / video-to-video generation using Helios.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`HeliosTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`HeliosScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["transformer"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ vae: AutoencoderKLWan,
+ scheduler: HeliosScheduler,
+ transformer: HeliosTransformer3DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: str | list[str] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: torch.device | None = None,
+ dtype: torch.dtype | None = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds, text_inputs.attention_mask.bool()
+
+ def encode_prompt(
+ self,
+ prompt: str | list[str],
+ negative_prompt: str | list[str] | None = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ max_sequence_length: int = 226,
+ device: torch.device | None = None,
+ dtype: torch.dtype | None = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `list[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, _ = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, _ = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ image=None,
+ video=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if image is not None and video is not None:
+            raise ValueError("Only one of `image` or `video` can be provided, not both.")
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 384,
+ width: int = 640,
+ num_frames: int = 33,
+ dtype: torch.dtype | None = None,
+ device: torch.device | None = None,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
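For reference, the latent-shape arithmetic used by `prepare_latents` can be sketched in isolation. This is an illustrative helper (not part of the pipeline API), assuming the default Wan VAE scale factors of 4 (temporal) and 8 (spatial):

```python
# Sketch of the latent-shape arithmetic in prepare_latents above.
# Hypothetical helper; assumes the default Wan VAE scale factors.
def latent_shape(batch_size, channels, num_frames, height, width,
                 temporal_scale=4, spatial_scale=8):
    # A video of F pixel frames maps to (F - 1) // temporal_scale + 1 latent frames.
    num_latent_frames = (num_frames - 1) // temporal_scale + 1
    return (batch_size, channels, num_latent_frames,
            height // spatial_scale, width // spatial_scale)


# 33 pixel frames at 384x640 -> 9 latent frames at 48x80
print(latent_shape(1, 16, 33, 384, 640))  # (1, 16, 9, 48, 80)
```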
+ def prepare_image_latents(
+ self,
+ image: torch.Tensor,
+ latents_mean: torch.Tensor,
+ latents_std: torch.Tensor,
+ num_latent_frames_per_chunk: int,
+ dtype: torch.dtype | None = None,
+ device: torch.device | None = None,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ fake_latents: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ if latents is None:
+ image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype)
+ latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+ latents = (latents - latents_mean) * latents_std
+ if fake_latents is None:
+ min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
+ fake_video = image.repeat(1, 1, min_frames, 1, 1).to(device=device, dtype=self.vae.dtype)
+ fake_latents_full = self.vae.encode(fake_video).latent_dist.sample(generator=generator)
+ fake_latents_full = (fake_latents_full - latents_mean) * latents_std
+ fake_latents = fake_latents_full[:, :, -1:, :, :]
+ return latents.to(device=device, dtype=dtype), fake_latents.to(device=device, dtype=dtype)
+
+ def prepare_video_latents(
+ self,
+ video: torch.Tensor,
+ latents_mean: torch.Tensor,
+ latents_std: torch.Tensor,
+ num_latent_frames_per_chunk: int,
+ dtype: torch.dtype | None = None,
+ device: torch.device | None = None,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ video = video.to(device=device, dtype=self.vae.dtype)
+ if latents is None:
+ num_frames = video.shape[2]
+ min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
+ num_chunks = num_frames // min_frames
+ if num_chunks == 0:
+ raise ValueError(
+ f"Video must have at least {min_frames} frames "
+ f"(got {num_frames} frames). "
+ f"Required: (num_latent_frames_per_chunk - 1) * {self.vae_scale_factor_temporal} + 1 = ({num_latent_frames_per_chunk} - 1) * {self.vae_scale_factor_temporal} + 1 = {min_frames}"
+ )
+ total_valid_frames = num_chunks * min_frames
+ start_frame = num_frames - total_valid_frames
+
+ first_frame = video[:, :, 0:1, :, :]
+ first_frame_latent = self.vae.encode(first_frame).latent_dist.sample(generator=generator)
+ first_frame_latent = (first_frame_latent - latents_mean) * latents_std
+
+ latents_chunks = []
+ for i in range(num_chunks):
+ chunk_start = start_frame + i * min_frames
+ chunk_end = chunk_start + min_frames
+ video_chunk = video[:, :, chunk_start:chunk_end, :, :]
+ chunk_latents = self.vae.encode(video_chunk).latent_dist.sample(generator=generator)
+ chunk_latents = (chunk_latents - latents_mean) * latents_std
+ latents_chunks.append(chunk_latents)
+ latents = torch.cat(latents_chunks, dim=2)
+ return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype)
+
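The chunking arithmetic in `prepare_video_latents` (minimum pixel frames per latent chunk, and dropping leading frames that do not fill a whole chunk) can be illustrated standalone. This is a hypothetical sketch, assuming a temporal scale factor of 4:

```python
# Sketch of the chunking arithmetic in prepare_video_latents above.
# Hypothetical helper; assumes vae_scale_factor_temporal = 4.
def chunk_plan(num_frames, num_latent_frames_per_chunk=9, temporal_scale=4):
    # A latent chunk of N frames requires (N - 1) * temporal_scale + 1 pixel frames.
    min_frames = (num_latent_frames_per_chunk - 1) * temporal_scale + 1
    num_chunks = num_frames // min_frames
    # Frames that do not fill a whole chunk are skipped at the front.
    start_frame = num_frames - num_chunks * min_frames
    return min_frames, num_chunks, start_frame


print(chunk_plan(100))  # (33, 3, 1): three chunks of 33 frames, first frame skipped
```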
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: str | list[str] = None,
+ negative_prompt: str | list[str] = None,
+ height: int = 384,
+ width: int = 640,
+ num_frames: int = 132,
+ num_inference_steps: int = 50,
+ sigmas: list[float] = None,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: int | None = 1,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ output_type: str | None = "np",
+ return_dict: bool = True,
+ attention_kwargs: dict[str, Any] | None = None,
+ callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
+ callback_on_step_end_tensor_inputs: list[str] = ["latents"],
+ max_sequence_length: int = 512,
+ # ------------ I2V ------------
+ image: PipelineImageInput | None = None,
+ image_latents: torch.Tensor | None = None,
+ fake_image_latents: torch.Tensor | None = None,
+ add_noise_to_image_latents: bool = True,
+ image_noise_sigma_min: float = 0.111,
+ image_noise_sigma_max: float = 0.135,
+ # ------------ V2V ------------
+ video: PipelineImageInput | None = None,
+ video_latents: torch.Tensor | None = None,
+ add_noise_to_video_latents: bool = True,
+ video_noise_sigma_min: float = 0.111,
+ video_noise_sigma_max: float = 0.135,
+ # ------------ Stage 1 ------------
+ history_sizes: list[int] = [16, 2, 1],
+ num_latent_frames_per_chunk: int = 9,
+ keep_first_frame: bool = True,
+ is_skip_first_chunk: bool = False,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `list[str]`, *optional*):
+ The prompt or prompts to guide video generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `list[str]`, *optional*):
+ The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` <= `1`).
+ height (`int`, defaults to `384`):
+ The height in pixels of the generated video.
+ width (`int`, defaults to `640`):
+ The width in pixels of the generated video.
+ num_frames (`int`, defaults to `132`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate videos that are
+ closely linked to the text `prompt`, usually at the expense of lower video quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated video. Choose between `"np"`, `"pt"`, `"pil"`, or `"latent"`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`HeliosPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during inference with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`list`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+
+ Examples:
+
+ Returns:
+ [`~HeliosPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`HeliosPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated video frames.
+ """
+
+ history_sizes = sorted(history_sizes, reverse=True) # From big to small
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ image,
+ video,
+ )
+
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+ vae_dtype = self.vae.dtype
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(device, self.vae.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device, self.vae.dtype
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare image or video
+ if image is not None:
+ image = self.video_processor.preprocess(image, height=height, width=width)
+ image_latents, fake_image_latents = self.prepare_image_latents(
+ image,
+ latents_mean=latents_mean,
+ latents_std=latents_std,
+ num_latent_frames_per_chunk=num_latent_frames_per_chunk,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=image_latents,
+ fake_latents=fake_image_latents,
+ )
+
+ if image_latents is not None and add_noise_to_image_latents:
+ image_noise_sigma = (
+ torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min)
+ + image_noise_sigma_min
+ )
+ image_latents = (
+ image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device)
+ + (1 - image_noise_sigma) * image_latents
+ )
+ fake_image_noise_sigma = (
+ torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min)
+ + video_noise_sigma_min
+ )
+ fake_image_latents = (
+ fake_image_noise_sigma * randn_tensor(fake_image_latents.shape, generator=generator, device=device)
+ + (1 - fake_image_noise_sigma) * fake_image_latents
+ )
+
+ if video is not None:
+ video = self.video_processor.preprocess_video(video, height=height, width=width)
+ image_latents, video_latents = self.prepare_video_latents(
+ video,
+ latents_mean=latents_mean,
+ latents_std=latents_std,
+ num_latent_frames_per_chunk=num_latent_frames_per_chunk,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=video_latents,
+ )
+
+ if video_latents is not None and add_noise_to_video_latents:
+ image_noise_sigma = (
+ torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min)
+ + image_noise_sigma_min
+ )
+ image_latents = (
+ image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device)
+ + (1 - image_noise_sigma) * image_latents
+ )
+
+ noisy_latents_chunks = []
+ num_latent_chunks = video_latents.shape[2] // num_latent_frames_per_chunk
+ for i in range(num_latent_chunks):
+ chunk_start = i * num_latent_frames_per_chunk
+ chunk_end = chunk_start + num_latent_frames_per_chunk
+ latent_chunk = video_latents[:, :, chunk_start:chunk_end, :, :]
+
+ chunk_frames = latent_chunk.shape[2]
+ frame_sigmas = (
+ torch.rand(chunk_frames, device=device, generator=generator)
+ * (video_noise_sigma_max - video_noise_sigma_min)
+ + video_noise_sigma_min
+ )
+ frame_sigmas = frame_sigmas.view(1, 1, chunk_frames, 1, 1)
+
+ noisy_chunk = (
+ frame_sigmas * randn_tensor(latent_chunk.shape, generator=generator, device=device)
+ + (1 - frame_sigmas) * latent_chunk
+ )
+ noisy_latents_chunks.append(noisy_chunk)
+ video_latents = torch.cat(noisy_latents_chunks, dim=2)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
+ num_latent_chunk = max(1, (num_frames + window_num_frames - 1) // window_num_frames)
+ num_history_latent_frames = sum(history_sizes)
+ history_video = None
+ total_generated_latent_frames = 0
+
+ if not keep_first_frame:
+ history_sizes[-1] = history_sizes[-1] + 1
+ history_latents = torch.zeros(
+ batch_size,
+ num_channels_latents,
+ num_history_latent_frames,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ device=device,
+ dtype=torch.float32,
+ )
+ if fake_image_latents is not None:
+ history_latents = torch.cat([history_latents[:, :, :-1, :, :], fake_image_latents], dim=2)
+ total_generated_latent_frames += 1
+ if video_latents is not None:
+ history_frames = history_latents.shape[2]
+ video_frames = video_latents.shape[2]
+ if video_frames < history_frames:
+ keep_frames = history_frames - video_frames
+ history_latents = torch.cat([history_latents[:, :, :keep_frames, :, :], video_latents], dim=2)
+ else:
+ history_latents = video_latents
+ total_generated_latent_frames += video_latents.shape[2]
+
+ if keep_first_frame:
+ indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk]))
+ (
+ indices_prefix,
+ indices_latents_history_long,
+ indices_latents_history_mid,
+ indices_latents_history_1x,
+ indices_hidden_states,
+ ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0)
+ indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0)
+ else:
+ indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk]))
+ (
+ indices_latents_history_long,
+ indices_latents_history_mid,
+ indices_latents_history_short,
+ indices_hidden_states,
+ ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0)
+ indices_hidden_states = indices_hidden_states.unsqueeze(0)
+ indices_latents_history_short = indices_latents_history_short.unsqueeze(0)
+ indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0)
+ indices_latents_history_long = indices_latents_history_long.unsqueeze(0)
+
+ # 6. Denoising loop
+ patch_size = self.transformer.config.patch_size
+ image_seq_len = (
+ num_latent_frames_per_chunk
+ * (height // self.vae_scale_factor_spatial)
+ * (width // self.vae_scale_factor_spatial)
+ // (patch_size[0] * patch_size[1] * patch_size[2])
+ )
+ sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+
+ for k in range(num_latent_chunk):
+ is_first_chunk = k == 0
+ is_second_chunk = k == 1
+ if keep_first_frame:
+ latents_history_long, latents_history_mid, latents_history_1x = history_latents[
+ :, :, -num_history_latent_frames:
+ ].split(history_sizes, dim=2)
+ if image_latents is None and is_first_chunk:
+ latents_prefix = torch.zeros(
+ (
+ batch_size,
+ num_channels_latents,
+ 1,
+ latents_history_1x.shape[-2],
+ latents_history_1x.shape[-1],
+ ),
+ device=device,
+ dtype=latents_history_1x.dtype,
+ )
+ else:
+ latents_prefix = image_latents
+ latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2)
+ else:
+ latents_history_long, latents_history_mid, latents_history_short = history_latents[
+ :, :, -num_history_latent_frames:
+ ].split(history_sizes, dim=2)
+
+ latents = self.prepare_latents(
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ window_num_frames,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=None,
+ )
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device, sigmas=sigmas, mu=mu)
+ timesteps = self.scheduler.timesteps
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ timestep = t.expand(latents.shape[0])
+
+ latent_model_input = latents.to(transformer_dtype)
+ latents_history_short = latents_history_short.to(transformer_dtype)
+ latents_history_mid = latents_history_mid.to(transformer_dtype)
+ latents_history_long = latents_history_long.to(transformer_dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ indices_hidden_states=indices_hidden_states,
+ indices_latents_history_short=indices_latents_history_short,
+ indices_latents_history_mid=indices_latents_history_mid,
+ indices_latents_history_long=indices_latents_history_long,
+ latents_history_short=latents_history_short,
+ latents_history_mid=latents_history_mid,
+ latents_history_long=latents_history_long,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ indices_hidden_states=indices_hidden_states,
+ indices_latents_history_short=indices_latents_history_short,
+ indices_latents_history_mid=indices_latents_history_mid,
+ indices_latents_history_long=indices_latents_history_long,
+ latents_history_short=latents_history_short,
+ latents_history_mid=latents_history_mid,
+ latents_history_long=latents_history_long,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ generator=generator,
+ return_dict=False,
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for tensor_name in callback_on_step_end_tensor_inputs:
+ callback_kwargs[tensor_name] = locals()[tensor_name]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if keep_first_frame and (
+ (is_first_chunk and image_latents is None) or (is_skip_first_chunk and is_second_chunk)
+ ):
+ image_latents = latents[:, :, 0:1, :, :]
+
+ total_generated_latent_frames += latents.shape[2]
+ history_latents = torch.cat([history_latents, latents], dim=2)
+ real_history_latents = history_latents[:, :, -total_generated_latent_frames:]
+ current_latents = (
+ real_history_latents[:, :, -num_latent_frames_per_chunk:].to(vae_dtype) / latents_std
+ + latents_mean
+ )
+ current_video = self.vae.decode(current_latents, return_dict=False)[0]
+
+ if history_video is None:
+ history_video = current_video
+ else:
+ history_video = torch.cat([history_video, current_video], dim=2)
+
+ self._current_timestep = None
+
+ if output_type != "latent":
+ generated_frames = history_video.size(2)
+ generated_frames = (
+ generated_frames - 1
+ ) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ history_video = history_video[:, :, :generated_frames]
+ video = self.video_processor.postprocess_video(history_video, output_type=output_type)
+ else:
+ video = real_history_latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return HeliosPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py
new file mode 100644
index 000000000000..40c1d65825ff
--- /dev/null
+++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py
@@ -0,0 +1,1065 @@
+# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import math
+from typing import Any, Callable
+
+import regex as re
+import torch
+import torch.nn.functional as F
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import HeliosLoraLoaderMixin
+from ...models import AutoencoderKLWan, HeliosTransformer3DModel
+from ...schedulers import HeliosDMDScheduler, HeliosScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import HeliosPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers.utils import export_to_video
+ >>> from diffusers import AutoencoderKLWan, HeliosPyramidPipeline
+
+ >>> # Available models: BestWishYsh/Helios-Base, BestWishYsh/Helios-Mid, BestWishYsh/Helios-Distilled
+ >>> model_id = "BestWishYsh/Helios-Base"
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = HeliosPyramidPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=384,
+ ... width=640,
+ ... num_frames=132,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=24)
+ ```
+"""
+
+
+def optimized_scale(positive_flat, negative_flat):
+ # Compute the dot product between the conditional and unconditional predictions
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
+ # Squared norm of the unconditional prediction
+ squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+ st_star = dot_product / squared_norm
+ return st_star
+
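The projection scalar computed by `optimized_scale` can be written in plain Python for clarity. This is an illustrative analogue (hypothetical names), not part of the module:

```python
# Plain-Python analogue of optimized_scale above: the per-sample scalar
# st_star = v_cond . v_uncond / ||v_uncond||^2 that projects the
# conditional prediction onto the unconditional one.
def st_star(positive, negative, eps=1e-8):
    dot = sum(p * n for p, n in zip(positive, negative))
    sq_norm = sum(n * n for n in negative) + eps
    return dot / sq_norm


# If the conditional prediction is exactly 2x the unconditional one,
# the optimal scale is 2.
print(round(st_star([2.0, 4.0], [1.0, 2.0]), 6))  # 2.0
```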
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
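As a sanity check, `calculate_shift` is a linear interpolation: `mu` moves from `base_shift` at `base_seq_len` to `max_shift` at `max_seq_len`. A standalone sketch (illustrative helper mirroring the function above):

```python
# Mirrors calculate_shift above: mu interpolates linearly between
# base_shift and max_shift as image_seq_len goes from base to max.
def calc_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
               base_shift=0.5, max_shift=1.15):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b


print(round(calc_shift(256), 4))   # 0.5 at the base sequence length
print(round(calc_shift(4096), 4))  # 1.15 at the max sequence length
```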
+
+class HeliosPyramidPipeline(DiffusionPipeline, HeliosLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video / image-to-video / video-to-video generation using Helios.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`HeliosTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`HeliosScheduler`] or [`HeliosDMDScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["transformer"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ vae: AutoencoderKLWan,
+ scheduler: HeliosScheduler | HeliosDMDScheduler,
+ transformer: HeliosTransformer3DModel,
+ is_cfg_zero_star: bool = False,
+ is_distilled: bool = False,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.register_to_config(is_cfg_zero_star=is_cfg_zero_star)
+ self.register_to_config(is_distilled=is_distilled)
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.helios.pipeline_helios.HeliosPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: str | list[str] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: torch.device | None = None,
+ dtype: torch.dtype | None = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds, text_inputs.attention_mask.bool()
+
+ def encode_prompt(
+ self,
+ prompt: str | list[str],
+ negative_prompt: str | list[str] | None = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ max_sequence_length: int = 226,
+ device: torch.device | None = None,
+ dtype: torch.dtype | None = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `list[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, _ = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, _ = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ image=None,
+ video=None,
+ guidance_scale=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if image is not None and video is not None:
+            raise ValueError("`image` and `video` cannot be provided simultaneously.")
+
+        if guidance_scale is not None and guidance_scale > 1.0 and self.config.is_distilled:
+ logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.")
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 384,
+ width: int = 640,
+ num_frames: int = 33,
+ dtype: torch.dtype | None = None,
+ device: torch.device | None = None,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ def prepare_image_latents(
+ self,
+ image: torch.Tensor,
+ latents_mean: torch.Tensor,
+ latents_std: torch.Tensor,
+ num_latent_frames_per_chunk: int,
+ dtype: torch.dtype | None = None,
+ device: torch.device | None = None,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ fake_latents: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ if latents is None:
+ image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype)
+ latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+ latents = (latents - latents_mean) * latents_std
+ if fake_latents is None:
+ min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
+ fake_video = image.repeat(1, 1, min_frames, 1, 1).to(device=device, dtype=self.vae.dtype)
+ fake_latents_full = self.vae.encode(fake_video).latent_dist.sample(generator=generator)
+ fake_latents_full = (fake_latents_full - latents_mean) * latents_std
+ fake_latents = fake_latents_full[:, :, -1:, :, :]
+ return latents.to(device=device, dtype=dtype), fake_latents.to(device=device, dtype=dtype)
+
+ def prepare_video_latents(
+ self,
+ video: torch.Tensor,
+ latents_mean: torch.Tensor,
+ latents_std: torch.Tensor,
+ num_latent_frames_per_chunk: int,
+ dtype: torch.dtype | None = None,
+ device: torch.device | None = None,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ video = video.to(device=device, dtype=self.vae.dtype)
+ if latents is None:
+ num_frames = video.shape[2]
+ min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
+ num_chunks = num_frames // min_frames
+ if num_chunks == 0:
+ raise ValueError(
+ f"Video must have at least {min_frames} frames "
+ f"(got {num_frames} frames). "
+ f"Required: (num_latent_frames_per_chunk - 1) * {self.vae_scale_factor_temporal} + 1 = ({num_latent_frames_per_chunk} - 1) * {self.vae_scale_factor_temporal} + 1 = {min_frames}"
+ )
+ total_valid_frames = num_chunks * min_frames
+ start_frame = num_frames - total_valid_frames
+
+ first_frame = video[:, :, 0:1, :, :]
+ first_frame_latent = self.vae.encode(first_frame).latent_dist.sample(generator=generator)
+ first_frame_latent = (first_frame_latent - latents_mean) * latents_std
+
+ latents_chunks = []
+ for i in range(num_chunks):
+ chunk_start = start_frame + i * min_frames
+ chunk_end = chunk_start + min_frames
+ video_chunk = video[:, :, chunk_start:chunk_end, :, :]
+ chunk_latents = self.vae.encode(video_chunk).latent_dist.sample(generator=generator)
+ chunk_latents = (chunk_latents - latents_mean) * latents_std
+ latents_chunks.append(chunk_latents)
+ latents = torch.cat(latents_chunks, dim=2)
+ return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype)
+
+ def sample_block_noise(
+ self,
+ batch_size,
+ channel,
+ num_frames,
+ height,
+ width,
+ patch_size: tuple[int, ...] = (1, 2, 2),
+ device: torch.device | None = None,
+ ):
+ gamma = self.scheduler.config.gamma
+ _, ph, pw = patch_size
+ block_size = ph * pw
+
+ cov = (
+ torch.eye(block_size, device=device) * (1 + gamma)
+ - torch.ones(block_size, block_size, device=device) * gamma
+ )
+ cov += torch.eye(block_size, device=device) * 1e-6
+ dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=device), covariance_matrix=cov)
+ block_number = batch_size * channel * num_frames * (height // ph) * (width // pw)
+
+ noise = dist.sample((block_number,)) # [block number, block_size]
+ noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw)
+ noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width)
+ return noise
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: str | list[str] = None,
+ negative_prompt: str | list[str] = None,
+ height: int = 384,
+ width: int = 640,
+ num_frames: int = 132,
+ sigmas: list[float] = None,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: int | None = 1,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ output_type: str | None = "np",
+ return_dict: bool = True,
+ attention_kwargs: dict[str, Any] | None = None,
+ callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
+ callback_on_step_end_tensor_inputs: list[str] = ["latents"],
+ max_sequence_length: int = 512,
+ # ------------ I2V ------------
+ image: PipelineImageInput | None = None,
+ image_latents: torch.Tensor | None = None,
+ fake_image_latents: torch.Tensor | None = None,
+ add_noise_to_image_latents: bool = True,
+ image_noise_sigma_min: float = 0.111,
+ image_noise_sigma_max: float = 0.135,
+ # ------------ V2V ------------
+ video: PipelineImageInput | None = None,
+ video_latents: torch.Tensor | None = None,
+ add_noise_to_video_latents: bool = True,
+ video_noise_sigma_min: float = 0.111,
+ video_noise_sigma_max: float = 0.135,
+ # ------------ Stage 1 ------------
+ history_sizes: list = [16, 2, 1],
+ num_latent_frames_per_chunk: int = 9,
+ keep_first_frame: bool = True,
+ is_skip_first_chunk: bool = False,
+ # ------------ Stage 2 ------------
+ pyramid_num_inference_steps_list: list = [10, 10, 10],
+ # ------------ CFG Zero ------------
+ use_zero_init: bool | None = True,
+ zero_steps: int | None = 1,
+ # ------------ DMD ------------
+ is_amplify_first_chunk: bool = False,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `list[str]`, *optional*):
+                The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `list[str]`, *optional*):
+                The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds`
+                instead. Ignored when not using guidance (`guidance_scale` < `1`).
+            height (`int`, defaults to `384`):
+                The height in pixels of the generated video.
+            width (`int`, defaults to `640`):
+                The width in pixels of the generated video.
+ num_frames (`int`, defaults to `132`):
+ The number of frames in the generated video.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. Higher guidance scale encourages the model to generate videos that are closely
+                linked to the text `prompt`, usually at the expense of lower video quality.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+ generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+                The output format of the generated video. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`HeliosPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                each denoising step during inference with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`list`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+
+ Examples:
+
+ Returns:
+            [`~HeliosPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`HeliosPipelineOutput`] is returned, otherwise a `tuple` is returned where
+                the first element is a list with the generated video frames.
+ """
+
+ history_sizes = sorted(history_sizes, reverse=True) # From big to small
+ pyramid_num_stages = len(pyramid_num_inference_steps_list)
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ image,
+ video,
+ guidance_scale,
+ )
+
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+ vae_dtype = self.vae.dtype
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(device, self.vae.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device, self.vae.dtype
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare image or video
+ if image is not None:
+ image = self.video_processor.preprocess(image, height=height, width=width)
+ image_latents, fake_image_latents = self.prepare_image_latents(
+ image,
+ latents_mean=latents_mean,
+ latents_std=latents_std,
+ num_latent_frames_per_chunk=num_latent_frames_per_chunk,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=image_latents,
+ fake_latents=fake_image_latents,
+ )
+
+ if image_latents is not None and add_noise_to_image_latents:
+ image_noise_sigma = (
+ torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min)
+ + image_noise_sigma_min
+ )
+ image_latents = (
+ image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device)
+ + (1 - image_noise_sigma) * image_latents
+ )
+ fake_image_noise_sigma = (
+ torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min)
+ + video_noise_sigma_min
+ )
+ fake_image_latents = (
+ fake_image_noise_sigma * randn_tensor(fake_image_latents.shape, generator=generator, device=device)
+ + (1 - fake_image_noise_sigma) * fake_image_latents
+ )
+
+ if video is not None:
+ video = self.video_processor.preprocess_video(video, height=height, width=width)
+ image_latents, video_latents = self.prepare_video_latents(
+ video,
+ latents_mean=latents_mean,
+ latents_std=latents_std,
+ num_latent_frames_per_chunk=num_latent_frames_per_chunk,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=video_latents,
+ )
+
+ if video_latents is not None and add_noise_to_video_latents:
+ image_noise_sigma = (
+ torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min)
+ + image_noise_sigma_min
+ )
+ image_latents = (
+ image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device)
+ + (1 - image_noise_sigma) * image_latents
+ )
+
+ noisy_latents_chunks = []
+ num_latent_chunks = video_latents.shape[2] // num_latent_frames_per_chunk
+ for i in range(num_latent_chunks):
+ chunk_start = i * num_latent_frames_per_chunk
+ chunk_end = chunk_start + num_latent_frames_per_chunk
+ latent_chunk = video_latents[:, :, chunk_start:chunk_end, :, :]
+
+ chunk_frames = latent_chunk.shape[2]
+ frame_sigmas = (
+ torch.rand(chunk_frames, device=device, generator=generator)
+ * (video_noise_sigma_max - video_noise_sigma_min)
+ + video_noise_sigma_min
+ )
+ frame_sigmas = frame_sigmas.view(1, 1, chunk_frames, 1, 1)
+
+ noisy_chunk = (
+ frame_sigmas * randn_tensor(latent_chunk.shape, generator=generator, device=device)
+ + (1 - frame_sigmas) * latent_chunk
+ )
+ noisy_latents_chunks.append(noisy_chunk)
+ video_latents = torch.cat(noisy_latents_chunks, dim=2)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
+ num_latent_chunk = max(1, (num_frames + window_num_frames - 1) // window_num_frames)
+ num_history_latent_frames = sum(history_sizes)
+ history_video = None
+ total_generated_latent_frames = 0
+
+ if not keep_first_frame:
+ history_sizes[-1] = history_sizes[-1] + 1
+ history_latents = torch.zeros(
+ batch_size,
+ num_channels_latents,
+ num_history_latent_frames,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ device=device,
+ dtype=torch.float32,
+ )
+ if fake_image_latents is not None:
+ history_latents = torch.cat([history_latents[:, :, :-1, :, :], fake_image_latents], dim=2)
+ total_generated_latent_frames += 1
+ if video_latents is not None:
+ history_frames = history_latents.shape[2]
+ video_frames = video_latents.shape[2]
+ if video_frames < history_frames:
+ keep_frames = history_frames - video_frames
+ history_latents = torch.cat([history_latents[:, :, :keep_frames, :, :], video_latents], dim=2)
+ else:
+ history_latents = video_latents
+ total_generated_latent_frames += video_latents.shape[2]
+
+ if keep_first_frame:
+ indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk]))
+ (
+ indices_prefix,
+ indices_latents_history_long,
+ indices_latents_history_mid,
+ indices_latents_history_1x,
+ indices_hidden_states,
+ ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0)
+ indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0)
+ else:
+ indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk]))
+ (
+ indices_latents_history_long,
+ indices_latents_history_mid,
+ indices_latents_history_short,
+ indices_hidden_states,
+ ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0)
+ indices_hidden_states = indices_hidden_states.unsqueeze(0)
+ indices_latents_history_short = indices_latents_history_short.unsqueeze(0)
+ indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0)
+ indices_latents_history_long = indices_latents_history_long.unsqueeze(0)
+
+ # 6. Denoising loop
+ for k in range(num_latent_chunk):
+ is_first_chunk = k == 0
+ is_second_chunk = k == 1
+ if keep_first_frame:
+ latents_history_long, latents_history_mid, latents_history_1x = history_latents[
+ :, :, -num_history_latent_frames:
+ ].split(history_sizes, dim=2)
+ if image_latents is None and is_first_chunk:
+ latents_prefix = torch.zeros(
+ (
+ batch_size,
+ num_channels_latents,
+ 1,
+ latents_history_1x.shape[-2],
+ latents_history_1x.shape[-1],
+ ),
+ device=device,
+ dtype=latents_history_1x.dtype,
+ )
+ else:
+ latents_prefix = image_latents
+ latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2)
+ else:
+ latents_history_long, latents_history_mid, latents_history_short = history_latents[
+ :, :, -num_history_latent_frames:
+ ].split(history_sizes, dim=2)
+
+ latents = self.prepare_latents(
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ window_num_frames,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=None,
+ )
+
+ num_inference_steps = (
+ sum(pyramid_num_inference_steps_list) * 2
+ if is_amplify_first_chunk and self.config.is_distilled and is_first_chunk
+ else sum(pyramid_num_inference_steps_list)
+ )
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ _, _, _, pyramid_height, pyramid_width = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(
+ batch_size * num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width
+ )
+ for _ in range(pyramid_num_stages - 1):
+ pyramid_height //= 2
+ pyramid_width //= 2
+ latents = (
+ F.interpolate(
+ latents,
+ size=(pyramid_height, pyramid_width),
+ mode="bilinear",
+ )
+ * 2
+ )
+ latents = latents.reshape(
+ batch_size, num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width
+ ).permute(0, 2, 1, 3, 4)
+
+ start_point_list = None
+ if self.config.is_distilled:
+ start_point_list = [latents]
+
+ for stage_idx in range(pyramid_num_stages):
+ patch_size = self.transformer.config.patch_size
+ image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // (
+ patch_size[0] * patch_size[1] * patch_size[2]
+ )
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ self.scheduler.set_timesteps(
+ pyramid_num_inference_steps_list[stage_idx],
+ stage_idx,
+ device=device,
+ mu=mu,
+ is_amplify_first_chunk=is_amplify_first_chunk and is_first_chunk,
+ )
+ timesteps = self.scheduler.timesteps
+ num_warmup_steps = 0
+ self._num_timesteps = len(timesteps)
+
+ if stage_idx > 0:
+ pyramid_height *= 2
+ pyramid_width *= 2
+ num_frames = latents.shape[2]
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(
+ batch_size * num_latent_frames_per_chunk,
+ num_channels_latents,
+ pyramid_height // 2,
+ pyramid_width // 2,
+ )
+ latents = F.interpolate(latents, size=(pyramid_height, pyramid_width), mode="nearest")
+ latents = latents.reshape(
+ batch_size,
+ num_latent_frames_per_chunk,
+ num_channels_latents,
+ pyramid_height,
+ pyramid_width,
+ ).permute(0, 2, 1, 3, 4)
+ # Fix the stage
+ ori_sigma = 1 - self.scheduler.ori_start_sigmas[stage_idx] # the original coeff of signal
+ gamma = self.scheduler.config.gamma
+ alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
+ beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)
+
+ batch_size, channel, num_frames, pyramid_height, pyramid_width = latents.shape
+ noise = self.sample_block_noise(
+ batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size, device
+ )
+ noise = noise.to(device=device, dtype=transformer_dtype)
+ latents = alpha * latents + beta * noise # To fix the block artifact
+
+ if self.config.is_distilled:
+ start_point_list.append(latents)
+
+ for i, t in enumerate(timesteps):
+ timestep = t.expand(latents.shape[0]).to(torch.int64)
+
+ latent_model_input = latents.to(transformer_dtype)
+ latents_history_short = latents_history_short.to(transformer_dtype)
+ latents_history_mid = latents_history_mid.to(transformer_dtype)
+ latents_history_long = latents_history_long.to(transformer_dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ indices_hidden_states=indices_hidden_states,
+ indices_latents_history_short=indices_latents_history_short,
+ indices_latents_history_mid=indices_latents_history_mid,
+ indices_latents_history_long=indices_latents_history_long,
+ latents_history_short=latents_history_short,
+ latents_history_mid=latents_history_mid,
+ latents_history_long=latents_history_long,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ indices_hidden_states=indices_hidden_states,
+ indices_latents_history_short=indices_latents_history_short,
+ indices_latents_history_mid=indices_latents_history_mid,
+ indices_latents_history_long=indices_latents_history_long,
+ latents_history_short=latents_history_short,
+ latents_history_mid=latents_history_mid,
+ latents_history_long=latents_history_long,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.config.is_cfg_zero_star:
+ noise_pred_text = noise_pred
+ positive_flat = noise_pred_text.view(batch_size, -1)
+ negative_flat = noise_uncond.view(batch_size, -1)
+
+ alpha = optimized_scale(positive_flat, negative_flat)
+ alpha = alpha.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1)))
+ alpha = alpha.to(noise_pred_text.dtype)
+
+ if (stage_idx == 0 and i <= zero_steps) and use_zero_init:
+ noise_pred = noise_pred_text * 0.0
+ else:
+ noise_pred = noise_uncond * alpha + guidance_scale * (
+ noise_pred_text - noise_uncond * alpha
+ )
+ else:
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ extra_kwargs = (
+ {
+ "cur_sampling_step": i,
+ "dmd_noisy_tensor": start_point_list[stage_idx]
+ if start_point_list is not None
+ else None,
+ "dmd_sigmas": self.scheduler.sigmas,
+ "dmd_timesteps": self.scheduler.timesteps,
+ "all_timesteps": timesteps,
+ }
+ if self.config.is_distilled
+ else {}
+ )
+
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ generator=generator,
+ return_dict=False,
+ **extra_kwargs,
+ )[0]
+
+ if callback_on_step_end is not None:
+                            callback_kwargs = {}
+                            # Use a distinct name to avoid shadowing the chunk index `k` from the outer loop
+                            for tensor_name in callback_on_step_end_tensor_inputs:
+                                callback_kwargs[tensor_name] = locals()[tensor_name]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop(
+ "negative_prompt_embeds", negative_prompt_embeds
+ )
+
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if keep_first_frame and (
+ (is_first_chunk and image_latents is None) or (is_skip_first_chunk and is_second_chunk)
+ ):
+ image_latents = latents[:, :, 0:1, :, :]
+
+ total_generated_latent_frames += latents.shape[2]
+ history_latents = torch.cat([history_latents, latents], dim=2)
+ real_history_latents = history_latents[:, :, -total_generated_latent_frames:]
+ current_latents = (
+ real_history_latents[:, :, -num_latent_frames_per_chunk:].to(vae_dtype) / latents_std
+ + latents_mean
+ )
+ current_video = self.vae.decode(current_latents, return_dict=False)[0]
+
+ if history_video is None:
+ history_video = current_video
+ else:
+ history_video = torch.cat([history_video, current_video], dim=2)
+
+ self._current_timestep = None
+
+ if output_type != "latent":
+ generated_frames = history_video.size(2)
+ generated_frames = (
+ generated_frames - 1
+ ) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ history_video = history_video[:, :, :generated_frames]
+ video = self.video_processor.postprocess_video(history_video, output_type=output_type)
+ else:
+ video = real_history_latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return HeliosPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/helios/pipeline_output.py b/src/diffusers/pipelines/helios/pipeline_output.py
new file mode 100644
index 000000000000..08546289ef4c
--- /dev/null
+++ b/src/diffusers/pipelines/helios/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class HeliosPipelineOutput(BaseOutput):
+ r"""
+ Output class for Helios pipelines.
+
+ Args:
+        frames (`torch.Tensor`, `np.ndarray`, or `List[List[PIL.Image.Image]]`):
+            List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
+            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 4199e75bf331..c7101d1b0401 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -61,6 +61,8 @@
_import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"]
_import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"]
_import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"]
+ _import_structure["scheduling_helios"] = ["HeliosScheduler"]
+ _import_structure["scheduling_helios_dmd"] = ["HeliosDMDScheduler"]
_import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"]
_import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
@@ -164,6 +166,8 @@
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
from .scheduling_flow_match_lcm import FlowMatchLCMScheduler
+ from .scheduling_helios import HeliosScheduler
+ from .scheduling_helios_dmd import HeliosDMDScheduler
from .scheduling_heun_discrete import HeunDiscreteScheduler
from .scheduling_ipndm import IPNDMScheduler
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
diff --git a/src/diffusers/schedulers/scheduling_helios.py b/src/diffusers/schedulers/scheduling_helios.py
new file mode 100644
index 000000000000..ed35245c9db3
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_helios.py
@@ -0,0 +1,867 @@
+# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Literal
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..schedulers.scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput, deprecate
+
+
+@dataclass
+class HeliosSchedulerOutput(BaseOutput):
+ prev_sample: torch.FloatTensor
+ model_outputs: torch.FloatTensor | None = None
+ last_sample: torch.FloatTensor | None = None
+ this_order: int | None = None
+
+
+class HeliosScheduler(SchedulerMixin, ConfigMixin):
+ _compatibles = []
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ shift: float = 1.0,  # Following Stable Diffusion 3
+ stages: int = 3,
+ stage_range: list = [0, 1 / 3, 2 / 3, 1],
+ gamma: float = 1 / 3,
+ # For UniPC
+ thresholding: bool = False,
+ prediction_type: str = "flow_prediction",
+ solver_order: int = 2,
+ predict_x0: bool = True,
+ solver_type: str = "bh2",
+ lower_order_final: bool = True,
+ disable_corrector: list[int] = [],
+ solver_p: SchedulerMixin = None,
+ use_flow_sigmas: bool = True,
+ scheduler_type: str = "unipc", # ["euler", "unipc"]
+ use_dynamic_shifting: bool = False,
+ time_shift_type: Literal["exponential", "linear"] = "exponential",
+ ):
+ self.timestep_ratios = {} # The timestep ratio for each stage
+ self.timesteps_per_stage = {} # The detailed timesteps per stage (fix max and min per stage)
+ self.sigmas_per_stage = {} # always uniform [1000, 0]
+ self.start_sigmas = {} # for start point / upsample renoise
+ self.end_sigmas = {} # for end point
+ self.ori_start_sigmas = {}
+
+ # self.init_sigmas()
+ self.init_sigmas_for_each_stage()
+ self.sigma_min = self.sigmas[-1].item()
+ self.sigma_max = self.sigmas[0].item()
+ self.gamma = gamma
+
+ if solver_type not in ["bh1", "bh2"]:
+ if solver_type in ["midpoint", "heun", "logrho"]:
+ self.register_to_config(solver_type="bh2")
+ else:
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+
+ self.predict_x0 = predict_x0
+ self.model_outputs = [None] * solver_order
+ self.timestep_list = [None] * solver_order
+ self.lower_order_nums = 0
+ self.disable_corrector = disable_corrector
+ self.solver_p = solver_p
+ self.last_sample = None
+ self._step_index = None
+ self._begin_index = None
+
+ def init_sigmas(self):
+ """
+ Initialize the global timesteps and sigmas.
+ """
+ num_train_timesteps = self.config.num_train_timesteps
+ shift = self.config.shift
+
+ alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps + 1)
+ sigmas = 1.0 - alphas
+ sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy()
+ sigmas = torch.from_numpy(sigmas)
+ timesteps = (sigmas * num_train_timesteps).clone()
+
+ self._step_index = None
+ self._begin_index = None
+ self.timesteps = timesteps
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
+
+ def init_sigmas_for_each_stage(self):
+ """
+ Initialize the timesteps and sigmas for each stage.
+ """
+ self.init_sigmas()
+
+ stage_distance = []
+ stages = self.config.stages
+ training_steps = self.config.num_train_timesteps
+ stage_range = self.config.stage_range
+
+ # Init the start and end point of each stage
+ for i_s in range(stages):
+ # Decide the start and end point of each stage
+ start_indice = int(stage_range[i_s] * training_steps)
+ start_indice = max(start_indice, 0)
+ end_indice = int(stage_range[i_s + 1] * training_steps)
+ end_indice = min(end_indice, training_steps)
+ start_sigma = self.sigmas[start_indice].item()
+ end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
+ self.ori_start_sigmas[i_s] = start_sigma
+
+ if i_s != 0:
+ ori_sigma = 1 - start_sigma
+ gamma = self.config.gamma
+ corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
+ # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
+ start_sigma = 1 - corrected_sigma
+
+ stage_distance.append(start_sigma - end_sigma)
+ self.start_sigmas[i_s] = start_sigma
+ self.end_sigmas[i_s] = end_sigma
+
+ # Determine the ratio of each stage according to flow length
+ tot_distance = sum(stage_distance)
+ for i_s in range(stages):
+ if i_s == 0:
+ start_ratio = 0.0
+ else:
+ start_ratio = sum(stage_distance[:i_s]) / tot_distance
+ if i_s == stages - 1:
+ end_ratio = 0.9999999999999999
+ else:
+ end_ratio = sum(stage_distance[: i_s + 1]) / tot_distance
+
+ self.timestep_ratios[i_s] = (start_ratio, end_ratio)
+
+ # Determine the timesteps and sigmas for each stage
+ for i_s in range(stages):
+ timestep_ratio = self.timestep_ratios[i_s]
+ # timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
+ timestep_max = min(self.timesteps[int(timestep_ratio[0] * training_steps)], 999)
+ timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
+ timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1)
+ self.timesteps_per_stage[i_s] = (
+ timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
+ )
+ stage_sigmas = np.linspace(0.999, 0, training_steps + 1)
+ self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])
+
+ @property
+ def step_index(self):
+ """
+ The index counter for the current timestep. It increases by 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
+ """
+ return self._begin_index
+
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ stage_index: int | None = None,
+ device: str | torch.device = None,
+ sigmas: np.ndarray | None = None,
+ mu: float | None = None,
+ is_amplify_first_chunk: bool = False,
+ ):
+ """
+ Set the timesteps and sigmas for each stage.
+ """
+ if self.config.scheduler_type == "dmd":
+ if is_amplify_first_chunk:
+ num_inference_steps = num_inference_steps * 2 + 1
+ else:
+ num_inference_steps = num_inference_steps + 1
+
+ self.num_inference_steps = num_inference_steps
+ self.init_sigmas()
+
+ if self.config.stages == 1:
+ if sigmas is None:
+ sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype(
+ np.float32
+ )
+ if self.config.shift != 1.0:
+ assert not self.config.use_dynamic_shifting
+ sigmas = self.time_shift(self.config.shift, 1.0, sigmas)
+ timesteps = (sigmas * self.config.num_train_timesteps).copy()
+ sigmas = torch.from_numpy(sigmas)
+ else:
+ stage_timesteps = self.timesteps_per_stage[stage_index]
+ timesteps = np.linspace(
+ stage_timesteps[0].item(),
+ stage_timesteps[-1].item(),
+ num_inference_steps,
+ )
+
+ stage_sigmas = self.sigmas_per_stage[stage_index]
+ ratios = np.linspace(stage_sigmas[0].item(), stage_sigmas[-1].item(), num_inference_steps)
+ sigmas = torch.from_numpy(ratios)
+
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
+ self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device=device)
+
+ self._step_index = None
+ self.reset_scheduler_history()
+
+ if self.config.scheduler_type == "dmd":
+ self.timesteps = self.timesteps[:-1]
+ self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]])
+
+ if self.config.use_dynamic_shifting:
+ assert self.config.shift == 1.0
+ self.sigmas = self.time_shift(mu, 1.0, self.sigmas)
+ if self.config.stages == 1:
+ self.timesteps = self.sigmas[:-1] * self.config.num_train_timesteps
+ else:
+ self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * (
+ self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min()
+ )
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+ """
+ Apply time shifting to the sigmas.
+
+ Args:
+ mu (`float`):
+ The mu parameter for the time shift.
+ sigma (`float`):
+ The sigma parameter for the time shift.
+ t (`torch.Tensor`):
+ The input timesteps.
+
+ Returns:
+ `torch.Tensor`:
+ The time-shifted timesteps.
+ """
+ if self.config.time_shift_type == "exponential":
+ return self._time_shift_exponential(mu, sigma, t)
+ elif self.config.time_shift_type == "linear":
+ return self._time_shift_linear(mu, sigma, t)
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
+ def _time_shift_exponential(self, mu, sigma, t):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
+ def _time_shift_linear(self, mu, sigma, t):
+ return mu / (mu + (1 / t - 1) ** sigma)
+
+ # ---------------------------------- Euler ----------------------------------
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ def _init_step_index(self, timestep):
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def step_euler(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: float | torch.FloatTensor = None,
+ sample: torch.FloatTensor = None,
+ generator: torch.Generator | None = None,
+ sigma: torch.FloatTensor | None = None,
+ sigma_next: torch.FloatTensor | None = None,
+ return_dict: bool = True,
+ ) -> HeliosSchedulerOutput | tuple:
+ assert (sigma is None) == (sigma_next is None), "sigma and sigma_next must both be None or both be not None"
+
+ if sigma is None and sigma_next is None:
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `HeliosScheduler.step_euler()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if self.step_index is None:
+ self._step_index = 0
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+
+ if sigma is None and sigma_next is None:
+ sigma = self.sigmas[self.step_index]
+ sigma_next = self.sigmas[self.step_index + 1]
+
+ prev_sample = sample + (sigma_next - sigma) * model_output
+
+ # Cast sample back to model compatible dtype
+ prev_sample = prev_sample.to(model_output.dtype)
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return HeliosSchedulerOutput(prev_sample=prev_sample)
+
+ # ---------------------------------- UniPC ----------------------------------
+ def _sigma_to_alpha_sigma_t(self, sigma):
+ if self.config.use_flow_sigmas:
+ alpha_t = 1 - sigma
+ sigma_t = torch.clamp(sigma, min=1e-8)
+ else:
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+ sigma_t = sigma * alpha_t
+
+ return alpha_t, sigma_t
+
+ def convert_model_output(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor = None,
+ sigma: torch.Tensor = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ r"""
+ Convert the model output to the corresponding type the UniPC algorithm needs.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The converted model output.
+ """
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
+ if sample is None:
+ if len(args) > 1:
+ sample = args[1]
+ else:
+ raise ValueError("missing `sample` as a required keyword argument")
+ if timestep is not None:
+ deprecate(
+ "timesteps",
+ "1.0.0",
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ flag = False
+ if sigma is None:
+ flag = True
+ sigma = self.sigmas[self.step_index]
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+
+ if self.predict_x0:
+ if self.config.prediction_type == "epsilon":
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
+ elif self.config.prediction_type == "sample":
+ x0_pred = model_output
+ elif self.config.prediction_type == "v_prediction":
+ x0_pred = alpha_t * sample - sigma_t * model_output
+ elif self.config.prediction_type == "flow_prediction":
+ if flag:
+ sigma_t = self.sigmas[self.step_index]
+ else:
+ sigma_t = sigma
+ x0_pred = sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ "`v_prediction`, or `flow_prediction` for the HeliosScheduler."
+ )
+
+ if self.config.thresholding:
+ x0_pred = self._threshold_sample(x0_pred)
+
+ return x0_pred
+ else:
+ if self.config.prediction_type == "epsilon":
+ return model_output
+ elif self.config.prediction_type == "sample":
+ epsilon = (sample - alpha_t * model_output) / sigma_t
+ return epsilon
+ elif self.config.prediction_type == "v_prediction":
+ epsilon = alpha_t * model_output + sigma_t * sample
+ return epsilon
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the HeliosScheduler."
+ )
+
+ def multistep_uni_p_bh_update(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor = None,
+ order: int = None,
+ sigma: torch.Tensor = None,
+ sigma_next: torch.Tensor = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model at the current timestep.
+ prev_timestep (`int`):
+ The previous discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ order (`int`):
+ The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
+
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
+ if sample is None:
+ if len(args) > 1:
+ sample = args[1]
+ else:
+ raise ValueError("missing `sample` as a required keyword argument")
+ if order is None:
+ if len(args) > 2:
+ order = args[2]
+ else:
+ raise ValueError("missing `order` as a required keyword argument")
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+ model_output_list = self.model_outputs
+
+ s0 = self.timestep_list[-1]
+ m0 = model_output_list[-1]
+ x = sample
+
+ if self.solver_p:
+ x_t = self.solver_p.step(model_output, s0, x).prev_sample
+ return x_t
+
+ if sigma_next is None and sigma is None:
+ sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+ else:
+ sigma_t, sigma_s0 = sigma_next, sigma
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+
+ h = lambda_t - lambda_s0
+ device = sample.device
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ si = self.step_index - i
+ mi = model_output_list[-(i + 1)]
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
+ rk = (lambda_si - lambda_s0) / h
+ rks.append(rk)
+ D1s.append((mi - m0) / rk)
+
+ rks.append(1.0)
+ rks = torch.tensor(rks, device=device)
+
+ R = []
+ b = []
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.config.solver_type == "bh1":
+ B_h = hh
+ elif self.config.solver_type == "bh2":
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= i + 1
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=device)
+
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1) # (B, K)
+ # for order 2, we use a simplified version
+ if order == 2:
+ rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
+ else:
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
+ else:
+ D1s = None
+
+ if self.predict_x0:
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+ if D1s is not None:
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
+ else:
+ pred_res = 0
+ x_t = x_t_ - alpha_t * B_h * pred_res
+ else:
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+ if D1s is not None:
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
+ else:
+ pred_res = 0
+ x_t = x_t_ - sigma_t * B_h * pred_res
+
+ x_t = x_t.to(x.dtype)
+ return x_t
+
+ def multistep_uni_c_bh_update(
+ self,
+ this_model_output: torch.Tensor,
+ *args,
+ last_sample: torch.Tensor = None,
+ this_sample: torch.Tensor = None,
+ order: int = None,
+ sigma_before: torch.Tensor = None,
+ sigma: torch.Tensor = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the UniC (B(h) version).
+
+ Args:
+ this_model_output (`torch.Tensor`):
+ The model outputs at `x_t`.
+ this_timestep (`int`):
+ The current timestep `t`.
+ last_sample (`torch.Tensor`):
+ The generated sample before the last predictor `x_{t-1}`.
+ this_sample (`torch.Tensor`):
+ The generated sample after the last predictor `x_{t}`.
+ order (`int`):
+ The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
+
+ Returns:
+ `torch.Tensor`:
+ The corrected sample tensor at the current timestep.
+ """
+ this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
+ if last_sample is None:
+ if len(args) > 1:
+ last_sample = args[1]
+ else:
+ raise ValueError("missing `last_sample` as a required keyword argument")
+ if this_sample is None:
+ if len(args) > 2:
+ this_sample = args[2]
+ else:
+ raise ValueError("missing `this_sample` as a required keyword argument")
+ if order is None:
+ if len(args) > 3:
+ order = args[3]
+ else:
+ raise ValueError("missing `order` as a required keyword argument")
+ if this_timestep is not None:
+ deprecate(
+ "this_timestep",
+ "1.0.0",
+ "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
+ )
+
+ model_output_list = self.model_outputs
+
+ m0 = model_output_list[-1]
+ x = last_sample
+ x_t = this_sample
+ model_t = this_model_output
+
+ if sigma_before is None and sigma is None:
+ sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
+ else:
+ sigma_t, sigma_s0 = sigma, sigma_before
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+
+ h = lambda_t - lambda_s0
+ device = this_sample.device
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ si = self.step_index - (i + 1)
+ mi = model_output_list[-(i + 1)]
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
+ rk = (lambda_si - lambda_s0) / h
+ rks.append(rk)
+ D1s.append((mi - m0) / rk)
+
+ rks.append(1.0)
+ rks = torch.tensor(rks, device=device)
+
+ R = []
+ b = []
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.config.solver_type == "bh1":
+ B_h = hh
+ elif self.config.solver_type == "bh2":
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= i + 1
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=device)
+
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1)
+ else:
+ D1s = None
+
+ # for order 1, we use a simplified version
+ if order == 1:
+ rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
+ else:
+ rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
+
+ if self.predict_x0:
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+ if D1s is not None:
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = model_t - m0
+ x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+ else:
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+ if D1s is not None:
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = model_t - m0
+ x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+ x_t = x_t.to(x.dtype)
+ return x_t
+
+ def step_unipc(
+ self,
+ model_output: torch.Tensor,
+ timestep: int | torch.Tensor = None,
+ sample: torch.Tensor = None,
+ return_dict: bool = True,
+ model_outputs: list = None,
+ timestep_list: list = None,
+ sigma_before: torch.Tensor = None,
+ sigma: torch.Tensor = None,
+ sigma_next: torch.Tensor = None,
+ cus_step_index: int = None,
+ cus_lower_order_num: int = None,
+ cus_this_order: int = None,
+ cus_last_sample: torch.Tensor = None,
+ ) -> HeliosSchedulerOutput | tuple:
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if cus_step_index is None:
+ if self.step_index is None:
+ self._step_index = 0
+ else:
+ self._step_index = cus_step_index
+
+ if cus_lower_order_num is not None:
+ self.lower_order_nums = cus_lower_order_num
+
+ if cus_this_order is not None:
+ self.this_order = cus_this_order
+
+ if cus_last_sample is not None:
+ self.last_sample = cus_last_sample
+
+ use_corrector = (
+ self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
+ )
+
+ # Convert model output using the proper conversion method
+ model_output_convert = self.convert_model_output(model_output, sample=sample, sigma=sigma)
+
+ if model_outputs is not None and timestep_list is not None:
+ self.model_outputs = model_outputs[:-1]
+ self.timestep_list = timestep_list[:-1]
+
+ if use_corrector:
+ sample = self.multistep_uni_c_bh_update(
+ this_model_output=model_output_convert,
+ last_sample=self.last_sample,
+ this_sample=sample,
+ order=self.this_order,
+ sigma_before=sigma_before,
+ sigma=sigma,
+ )
+
+ if model_outputs is not None and timestep_list is not None:
+ model_outputs[-1] = model_output_convert
+ self.model_outputs = model_outputs[1:]
+ self.timestep_list = timestep_list[1:]
+ else:
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.timestep_list[i] = self.timestep_list[i + 1]
+ self.model_outputs[-1] = model_output_convert
+ self.timestep_list[-1] = timestep
+
+ if self.config.lower_order_final:
+ this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
+ else:
+ this_order = self.config.solver_order
+ self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
+ assert self.this_order > 0
+
+ self.last_sample = sample
+ prev_sample = self.multistep_uni_p_bh_update(
+ model_output=model_output, # pass the original non-converted model output, in case solver-p is used
+ sample=sample,
+ order=self.this_order,
+ sigma=sigma,
+ sigma_next=sigma_next,
+ )
+
+ if cus_lower_order_num is None:
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ # upon completion increase step index by one
+ if cus_step_index is None:
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample, model_outputs, self.last_sample, self.this_order)
+
+ return HeliosSchedulerOutput(
+ prev_sample=prev_sample,
+ model_outputs=model_outputs,
+ last_sample=self.last_sample,
+ this_order=self.this_order,
+ )
+
+ # ---------------------------------- Merge ----------------------------------
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: float | torch.FloatTensor = None,
+ sample: torch.FloatTensor = None,
+ generator: torch.Generator | None = None,
+ return_dict: bool = True,
+ ) -> HeliosSchedulerOutput | tuple:
+ if self.config.scheduler_type == "euler":
+ return self.step_euler(
+ model_output=model_output,
+ timestep=timestep,
+ sample=sample,
+ generator=generator,
+ return_dict=return_dict,
+ )
+ elif self.config.scheduler_type == "unipc":
+ return self.step_unipc(
+ model_output=model_output,
+ timestep=timestep,
+ sample=sample,
+ return_dict=return_dict,
+ )
+ else:
+ raise NotImplementedError
+
+ def reset_scheduler_history(self):
+ self.model_outputs = [None] * self.config.solver_order
+ self.timestep_list = [None] * self.config.solver_order
+ self.lower_order_nums = 0
+ self.disable_corrector = self.config.disable_corrector
+ self.solver_p = self.config.solver_p
+ self.last_sample = None
+ self._step_index = None
+ self._begin_index = None
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_helios_dmd.py b/src/diffusers/schedulers/scheduling_helios_dmd.py
new file mode 100644
index 000000000000..1f4afa0e3128
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_helios_dmd.py
@@ -0,0 +1,331 @@
+# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Literal
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..schedulers.scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput
+
+
+@dataclass
+class HeliosDMDSchedulerOutput(BaseOutput):
+ prev_sample: torch.FloatTensor
+ model_outputs: torch.FloatTensor | None = None
+ last_sample: torch.FloatTensor | None = None
+ this_order: int | None = None
+
+
+class HeliosDMDScheduler(SchedulerMixin, ConfigMixin):
+ _compatibles = []
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ shift: float = 1.0,  # Following Stable Diffusion 3
+ stages: int = 3,
+ stage_range: list = [0, 1 / 3, 2 / 3, 1],
+ gamma: float = 1 / 3,
+ prediction_type: str = "flow_prediction",
+ use_flow_sigmas: bool = True,
+ use_dynamic_shifting: bool = False,
+ time_shift_type: Literal["exponential", "linear"] = "linear",
+ ):
+ self.timestep_ratios = {} # The timestep ratio for each stage
+ self.timesteps_per_stage = {} # The detailed timesteps per stage (fix max and min per stage)
+ self.sigmas_per_stage = {} # always uniform [1000, 0]
+ self.start_sigmas = {} # for start point / upsample renoise
+ self.end_sigmas = {} # for end point
+ self.ori_start_sigmas = {}
+
+ # self.init_sigmas()
+ self.init_sigmas_for_each_stage()
+ self.sigma_min = self.sigmas[-1].item()
+ self.sigma_max = self.sigmas[0].item()
+ self.gamma = gamma
+
+ self.last_sample = None
+ self._step_index = None
+ self._begin_index = None
+
+ def init_sigmas(self):
+ """
+ Initialize the global timesteps and sigmas.
+ """
+ num_train_timesteps = self.config.num_train_timesteps
+ shift = self.config.shift
+
+ alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps + 1)
+ sigmas = 1.0 - alphas
+ sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy()
+ sigmas = torch.from_numpy(sigmas)
+ timesteps = (sigmas * num_train_timesteps).clone()
+
+ self._step_index = None
+ self._begin_index = None
+ self.timesteps = timesteps
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
+
+ def init_sigmas_for_each_stage(self):
+ """
+ Initialize the timesteps and sigmas for each stage.
+ """
+ self.init_sigmas()
+
+ stage_distance = []
+ stages = self.config.stages
+ training_steps = self.config.num_train_timesteps
+ stage_range = self.config.stage_range
+
+ # Init the start and end point of each stage
+ for i_s in range(stages):
+ # Decide the start and end point of each stage
+ start_indice = int(stage_range[i_s] * training_steps)
+ start_indice = max(start_indice, 0)
+ end_indice = int(stage_range[i_s + 1] * training_steps)
+ end_indice = min(end_indice, training_steps)
+ start_sigma = self.sigmas[start_indice].item()
+ end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
+ self.ori_start_sigmas[i_s] = start_sigma
+
+ if i_s != 0:
+ ori_sigma = 1 - start_sigma
+ gamma = self.config.gamma
+ corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
+ # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
+ start_sigma = 1 - corrected_sigma
+
+ stage_distance.append(start_sigma - end_sigma)
+ self.start_sigmas[i_s] = start_sigma
+ self.end_sigmas[i_s] = end_sigma
+
+ # Determine the ratio of each stage according to flow length
+ tot_distance = sum(stage_distance)
+ for i_s in range(stages):
+ if i_s == 0:
+ start_ratio = 0.0
+ else:
+ start_ratio = sum(stage_distance[:i_s]) / tot_distance
+ if i_s == stages - 1:
+ end_ratio = 0.9999999999999999
+ else:
+ end_ratio = sum(stage_distance[: i_s + 1]) / tot_distance
+
+ self.timestep_ratios[i_s] = (start_ratio, end_ratio)
+
+ # Determine the timesteps and sigmas for each stage
+ for i_s in range(stages):
+ timestep_ratio = self.timestep_ratios[i_s]
+ # timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
+ timestep_max = min(self.timesteps[int(timestep_ratio[0] * training_steps)], 999)
+ timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
+ timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1)
+ self.timesteps_per_stage[i_s] = (
+ timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
+ )
+ stage_sigmas = np.linspace(0.999, 0, training_steps + 1)
+ self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])
+
+ @property
+ def step_index(self):
+ """
+ The index counter for the current timestep. It increases by 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index of the first timestep. It should be set from the pipeline with the `set_begin_index` method.
+ """
+ return self._begin_index
+
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ stage_index: int | None = None,
+ device: str | torch.device | None = None,
+ sigmas: np.ndarray | None = None,
+ mu: float | None = None,
+ is_amplify_first_chunk: bool = False,
+ ):
+ """
+ Set the timesteps and sigmas for the given stage.
+ """
+ if is_amplify_first_chunk:
+ num_inference_steps = num_inference_steps * 2 + 1
+ else:
+ num_inference_steps = num_inference_steps + 1
+
+ self.num_inference_steps = num_inference_steps
+ self.init_sigmas()
+
+ if self.config.stages == 1:
+ if sigmas is None:
+ sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype(
+ np.float32
+ )
+ if self.config.shift != 1.0:
+ assert not self.config.use_dynamic_shifting
+ sigmas = self.time_shift(self.config.shift, 1.0, sigmas)
+ timesteps = (sigmas * self.config.num_train_timesteps).copy()
+ sigmas = torch.from_numpy(sigmas)
+ else:
+ stage_timesteps = self.timesteps_per_stage[stage_index]
+ timesteps = np.linspace(
+ stage_timesteps[0].item(),
+ stage_timesteps[-1].item(),
+ num_inference_steps,
+ )
+
+ stage_sigmas = self.sigmas_per_stage[stage_index]
+ ratios = np.linspace(stage_sigmas[0].item(), stage_sigmas[-1].item(), num_inference_steps)
+ sigmas = torch.from_numpy(ratios)
+
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
+ self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device=device)
+
+ self._step_index = None
+ self.reset_scheduler_history()
+
+ self.timesteps = self.timesteps[:-1]
+ self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]])
+
+ if self.config.use_dynamic_shifting:
+ assert self.config.shift == 1.0
+ self.sigmas = self.time_shift(mu, 1.0, self.sigmas)
+ if self.config.stages == 1:
+ self.timesteps = self.sigmas[:-1] * self.config.num_train_timesteps
+ else:
+ self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * (
+ self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min()
+ )
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+ """
+ Apply time shifting to the sigmas.
+
+ Args:
+ mu (`float`):
+ The mu parameter for the time shift.
+ sigma (`float`):
+ The sigma parameter for the time shift.
+ t (`torch.Tensor`):
+ The input timesteps.
+
+ Returns:
+ `torch.Tensor`:
+ The time-shifted timesteps.
+ """
+ if self.config.time_shift_type == "exponential":
+ return self._time_shift_exponential(mu, sigma, t)
+ elif self.config.time_shift_type == "linear":
+ return self._time_shift_linear(mu, sigma, t)
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
+ def _time_shift_exponential(self, mu, sigma, t):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
+ def _time_shift_linear(self, mu, sigma, t):
+ return mu / (mu + (1 / t - 1) ** sigma)
+
+ # ---------------------------------- For DMD ----------------------------------
+ def add_noise(self, original_samples, noise, timestep, sigmas, timesteps):
+ sigmas = sigmas.to(noise.device)
+ timesteps = timesteps.to(noise.device)
+ timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
+ sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1)
+ sample = (1 - sigma) * original_samples + sigma * noise
+ return sample.type_as(noise)
+
+ def convert_flow_pred_to_x0(self, flow_pred, xt, timestep, sigmas, timesteps):
+ # use higher precision for calculations
+ original_dtype = flow_pred.dtype
+ device = flow_pred.device
+ flow_pred, xt, sigmas, timesteps = (x.double().to(device) for x in (flow_pred, xt, sigmas, timesteps))
+
+ timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
+ sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1)
+ x0_pred = xt - sigma_t * flow_pred
+ return x0_pred.to(original_dtype)
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: float | torch.FloatTensor | None = None,
+ sample: torch.FloatTensor | None = None,
+ generator: torch.Generator | None = None,
+ return_dict: bool = True,
+ cur_sampling_step: int = 0,
+ dmd_noisy_tensor: torch.FloatTensor | None = None,
+ dmd_sigmas: torch.FloatTensor | None = None,
+ dmd_timesteps: torch.FloatTensor | None = None,
+ all_timesteps: torch.FloatTensor | None = None,
+ ) -> HeliosDMDSchedulerOutput | tuple:
+ pred_image_or_video = self.convert_flow_pred_to_x0(
+ flow_pred=model_output,
+ xt=sample,
+ timestep=torch.full((model_output.shape[0],), timestep, dtype=torch.long, device=model_output.device),
+ sigmas=dmd_sigmas,
+ timesteps=dmd_timesteps,
+ )
+ if cur_sampling_step < len(all_timesteps) - 1:
+ prev_sample = self.add_noise(
+ pred_image_or_video,
+ dmd_noisy_tensor,
+ torch.full(
+ (model_output.shape[0],),
+ all_timesteps[cur_sampling_step + 1],
+ dtype=torch.long,
+ device=model_output.device,
+ ),
+ sigmas=dmd_sigmas,
+ timesteps=dmd_timesteps,
+ )
+ else:
+ prev_sample = pred_image_or_video
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return HeliosDMDSchedulerOutput(prev_sample=prev_sample)
+
+ def reset_scheduler_history(self):
+ self._step_index = None
+ self._begin_index = None
+
+ def __len__(self):
+ return self.config.num_train_timesteps
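The DMD helpers in the scheduler above rest on the standard rectified-flow identities. As a standalone numerical sketch (not part of the diff; the `num_train_timesteps` and `shift` values are illustrative, whereas the real scheduler reads them from its config), the shifted sigma schedule from `init_sigmas` and the `add_noise` / `convert_flow_pred_to_x0` round trip can be checked like this:

```python
import numpy as np

# Illustrative values; the scheduler reads these from self.config.
num_train_timesteps = 1000
shift = 3.0

# Same construction as init_sigmas: linear sigmas warped by the shift factor,
# then flipped to descend from ~1 toward 0 and truncated by one entry.
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy()

# Flow-matching identities behind add_noise and convert_flow_pred_to_x0:
#   x_t  = (1 - sigma) * x0 + sigma * noise
#   flow = noise - x0      =>      x0 = x_t - sigma * flow
x0, noise = 0.25, 1.5
sigma = float(sigmas[100])
xt = (1 - sigma) * x0 + sigma * noise
flow = noise - x0
x0_recovered = xt - sigma * flow

print(round(x0_recovered, 6))  # recovers the original sample: 0.25
```

The round trip is exact up to floating-point error, which is why `convert_flow_pred_to_x0` can invert `add_noise` without tracking any extra state.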
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 4e402921aa5f..3a4aecd24f90 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1031,6 +1031,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class HeliosTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class HiDreamImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -2743,6 +2758,36 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class HeliosDMDScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class HeliosScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class HeunDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 8758c549ca77..b86b5d2c6f4d 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -1352,6 +1352,36 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class HeliosPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class HeliosPyramidPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class HiDreamImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/lora/test_lora_layers_helios.py b/tests/lora/test_lora_layers_helios.py
new file mode 100644
index 000000000000..fbcc3b808eee
--- /dev/null
+++ b/tests/lora/test_lora_layers_helios.py
@@ -0,0 +1,120 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, HeliosPipeline, HeliosTransformer3DModel
+
+from ..testing_utils import floats_tensor, require_peft_backend, skip_mps
+
+
+sys.path.append(".")
+
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+@skip_mps
+class HeliosLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = HeliosPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "in_channels": 16,
+ "out_channels": 16,
+ "text_dim": 32,
+ "freq_dim": 256,
+ "ffn_dim": 32,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_dim": (4, 4, 4),
+ "has_multi_term_memory_patch": True,
+ "guidance_cross_attn": True,
+ "zero_history_timestep": True,
+ "is_amplify_history": False,
+ }
+ transformer_cls = HeliosTransformer3DModel
+ vae_kwargs = {
+ "base_dim": 3,
+ "z_dim": 16,
+ "dim_mult": [1, 1, 1, 1],
+ "num_res_blocks": 1,
+ "temperal_downsample": [False, True, True],
+ }
+ vae_cls = AutoencoderKLWan
+ has_two_text_encoders = True
+ tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+ text_encoder_target_modules = ["q", "k", "v", "o"]
+
+ supports_text_encoder_loras = False
+
+ @property
+ def output_shape(self):
+ return (1, 33, 32, 32, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ num_frames = 9
+ num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
+ sizes = (4, 4)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "",
+ "num_frames": num_frames,
+ "num_inference_steps": 1,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": sequence_length,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+
+ @unittest.skip("Not supported in Helios.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Helios.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in Helios.")
+ def test_modify_padding_mode(self):
+ pass
diff --git a/tests/models/transformers/test_models_transformer_helios.py b/tests/models/transformers/test_models_transformer_helios.py
new file mode 100644
index 000000000000..c365c258e596
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_helios.py
@@ -0,0 +1,168 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+import torch
+
+from diffusers import HeliosTransformer3DModel
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..testing_utils import (
+ AttentionTesterMixin,
+ BaseModelTesterConfig,
+ MemoryTesterMixin,
+ ModelTesterMixin,
+ TorchCompileTesterMixin,
+ TrainingTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class HeliosTransformer3DTesterConfig(BaseModelTesterConfig):
+ @property
+ def model_class(self):
+ return HeliosTransformer3DModel
+
+ @property
+ def pretrained_model_name_or_path(self):
+ return "hf-internal-testing/tiny-helios-base-transformer"
+
+ @property
+ def output_shape(self) -> tuple[int, ...]:
+ return (4, 2, 16, 16)
+
+ @property
+ def input_shape(self) -> tuple[int, ...]:
+ return (4, 2, 16, 16)
+
+ @property
+ def main_input_name(self) -> str:
+ return "hidden_states"
+
+ @property
+ def generator(self):
+ return torch.Generator("cpu").manual_seed(0)
+
+ def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
+ return {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "in_channels": 4,
+ "out_channels": 4,
+ "text_dim": 16,
+ "freq_dim": 256,
+ "ffn_dim": 32,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_dim": (4, 4, 4),
+ "has_multi_term_memory_patch": True,
+ "guidance_cross_attn": True,
+ "zero_history_timestep": True,
+ "is_amplify_history": False,
+ }
+
+ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
+ batch_size = 1
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = randn_tensor(
+ (batch_size, num_channels, num_frames, height, width),
+ generator=self.generator,
+ device=torch_device,
+ )
+ timestep = torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device)
+ encoder_hidden_states = randn_tensor(
+ (batch_size, sequence_length, text_encoder_embedding_dim),
+ generator=self.generator,
+ device=torch_device,
+ )
+ indices_hidden_states = torch.ones((batch_size, num_frames)).to(torch_device)
+ indices_latents_history_short = torch.ones((batch_size, num_frames - 1)).to(torch_device)
+ indices_latents_history_mid = torch.ones((batch_size, num_frames - 1)).to(torch_device)
+ indices_latents_history_long = torch.ones((batch_size, (num_frames - 1) * 4)).to(torch_device)
+ latents_history_short = randn_tensor(
+ (batch_size, num_channels, num_frames - 1, height, width),
+ generator=self.generator,
+ device=torch_device,
+ )
+ latents_history_mid = randn_tensor(
+ (batch_size, num_channels, num_frames - 1, height, width),
+ generator=self.generator,
+ device=torch_device,
+ )
+ latents_history_long = randn_tensor(
+ (batch_size, num_channels, (num_frames - 1) * 4, height, width),
+ generator=self.generator,
+ device=torch_device,
+ )
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "indices_hidden_states": indices_hidden_states,
+ "indices_latents_history_short": indices_latents_history_short,
+ "indices_latents_history_mid": indices_latents_history_mid,
+ "indices_latents_history_long": indices_latents_history_long,
+ "latents_history_short": latents_history_short,
+ "latents_history_mid": latents_history_mid,
+ "latents_history_long": latents_history_long,
+ }
+
+
+class TestHeliosTransformer3D(HeliosTransformer3DTesterConfig, ModelTesterMixin):
+ """Core model tests for Helios Transformer 3D."""
+
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
+ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
+ # Skip: fp16/bf16 require very high atol to pass, providing little signal.
+ # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
+ pytest.skip("Tolerance requirements too high for meaningful test")
+
+
+class TestHeliosTransformer3DMemory(HeliosTransformer3DTesterConfig, MemoryTesterMixin):
+ """Memory optimization tests for Helios Transformer 3D."""
+
+
+class TestHeliosTransformer3DTraining(HeliosTransformer3DTesterConfig, TrainingTesterMixin):
+ """Training tests for Helios Transformer 3D."""
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"HeliosTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class TestHeliosTransformer3DAttention(HeliosTransformer3DTesterConfig, AttentionTesterMixin):
+ """Attention processor tests for Helios Transformer 3D."""
+
+
+class TestHeliosTransformer3DCompile(HeliosTransformer3DTesterConfig, TorchCompileTesterMixin):
+ """Torch compile tests for Helios Transformer 3D."""
+
+ @pytest.mark.xfail(
+ reason="Helios DiT does not compile when deterministic algorithms are used due to https://github.com/pytorch/pytorch/issues/170079"
+ )
+ def test_torch_compile_recompilation_and_graph_break(self):
+ super().test_torch_compile_recompilation_and_graph_break()
diff --git a/tests/pipelines/helios/__init__.py b/tests/pipelines/helios/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/helios/test_helios.py b/tests/pipelines/helios/test_helios.py
new file mode 100644
index 000000000000..b8ee99085036
--- /dev/null
+++ b/tests/pipelines/helios/test_helios.py
@@ -0,0 +1,172 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import torch
+from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, HeliosPipeline, HeliosScheduler, HeliosTransformer3DModel
+
+from ...testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = HeliosPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = HeliosScheduler(stage_range=[0, 1], stages=1, use_dynamic_shifting=True)
+ config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
+ text_encoder = T5EncoderModel(config)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = HeliosTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_dim=(4, 4, 4),
+ has_multi_term_memory_patch=True,
+ guidance_cross_attn=True,
+ zero_history_timestep=True,
+ is_amplify_history=False,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 1.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (33, 3, 16, 16))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4529, 0.4527, 0.4499, 0.4542, 0.4528, 0.4524, 0.4531, 0.4534, 0.5328,
+ 0.5340, 0.5012, 0.5135, 0.5322, 0.5203, 0.5144, 0.5101])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ # Override to set a more lenient max diff threshold.
+ def test_save_load_float16(self):
+ super().test_save_load_float16(expected_max_diff=0.03)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("Optional components not applicable for Helios")
+ def test_save_load_optional_components(self):
+ pass
+
+
+@slow
+@require_torch_accelerator
+class HeliosPipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ @unittest.skip("TODO: test needs to be implemented")
+ def test_helios(self):
+ pass