From fc4a76ea3e9f987a54608b82aadf0e2a357afa8d Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Tue, 24 Feb 2026 14:20:17 +0000 Subject: [PATCH 001/107] [1/N] add helios --- 0_temp_helios_test/infer_helios.py | 576 +++++++ 0_temp_helios_test/requirements.txt | 35 + 0_temp_helios_test/stage-1_i2v.sh | 14 + 0_temp_helios_test/stage-1_t2v.sh | 14 + 0_temp_helios_test/stage-1_v2v.sh | 14 + 0_temp_helios_test/stage-2_i2v.sh | 15 + 0_temp_helios_test/stage-2_t2v.sh | 15 + 0_temp_helios_test/stage-2_v2v.sh | 15 + 0_temp_helios_test/stage-3_i2v.sh | 17 + 0_temp_helios_test/stage-3_t2v.sh | 16 + 0_temp_helios_test/stage-3_v2v.sh | 17 + docs/source/en/_toctree.yml | 9 +- .../en/api/models/helios_transformer3d.md | 35 + docs/source/en/api/pipelines/helios.md | 463 +++++ docs/source/en/api/schedulers/helios_unipc.md | 20 + docs/source/en/using-diffusers/helios.md | 67 + docs/source/zh/_toctree.yml | 2 + docs/source/zh/community_projects.md | 8 + docs/source/zh/using-diffusers/helios.md | 68 + src/diffusers/__init__.py | 6 + src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 274 +++ src/diffusers/loaders/peft.py | 1 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_helios.py | 943 ++++++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/auto_pipeline.py | 2 + src/diffusers/pipelines/helios/__init__.py | 46 + .../pipelines/helios/pipeline_helios.py | 1519 +++++++++++++++++ .../pipelines/helios/pipeline_output.py | 20 + src/diffusers/schedulers/__init__.py | 2 + .../schedulers/scheduling_helios_unipc.py | 817 +++++++++ src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 15 + .../test_models_transformer_helios.py | 0 tests/pipelines/helios/__init__.py | 0 tests/pipelines/helios/test_helios.py | 13 + 38 files changed, 5113 insertions(+), 1 deletion(-) create mode 100644 0_temp_helios_test/infer_helios.py create 
mode 100644 0_temp_helios_test/requirements.txt create mode 100644 0_temp_helios_test/stage-1_i2v.sh create mode 100644 0_temp_helios_test/stage-1_t2v.sh create mode 100644 0_temp_helios_test/stage-1_v2v.sh create mode 100644 0_temp_helios_test/stage-2_i2v.sh create mode 100644 0_temp_helios_test/stage-2_t2v.sh create mode 100644 0_temp_helios_test/stage-2_v2v.sh create mode 100644 0_temp_helios_test/stage-3_i2v.sh create mode 100644 0_temp_helios_test/stage-3_t2v.sh create mode 100644 0_temp_helios_test/stage-3_v2v.sh create mode 100644 docs/source/en/api/models/helios_transformer3d.md create mode 100644 docs/source/en/api/pipelines/helios.md create mode 100644 docs/source/en/api/schedulers/helios_unipc.md create mode 100644 docs/source/en/using-diffusers/helios.md create mode 100644 docs/source/zh/using-diffusers/helios.md create mode 100644 src/diffusers/models/transformers/transformer_helios.py create mode 100644 src/diffusers/pipelines/helios/__init__.py create mode 100644 src/diffusers/pipelines/helios/pipeline_helios.py create mode 100644 src/diffusers/pipelines/helios/pipeline_output.py create mode 100644 src/diffusers/schedulers/scheduling_helios_unipc.py create mode 100644 tests/models/transformers/test_models_transformer_helios.py create mode 100644 tests/pipelines/helios/__init__.py create mode 100644 tests/pipelines/helios/test_helios.py diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py new file mode 100644 index 000000000000..43fbabf47392 --- /dev/null +++ b/0_temp_helios_test/infer_helios.py @@ -0,0 +1,576 @@ +import os + + +os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes" +os.environ["HF_PARALLEL_LOADING_WORKERS"] = "8" + +import argparse +import time + +import pandas as pd +import torch +import torch.distributed as dist +from tqdm import tqdm + +from diffusers import HeliosTransformer3DModel +from diffusers import HeliosPipeline +from diffusers.schedulers.scheduling_helios_unipc import HeliosUniPCScheduler + +from 
diffusers import ContextParallelConfig +from diffusers.models import AutoencoderKLWan +from diffusers.utils import export_to_video, load_image, load_video + + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate video with model") + + # === Model paths === + parser.add_argument("--base_model_path", type=str, default="/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base") + parser.add_argument( + "--transformer_path", + type=str, + default="/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base", + ) + parser.add_argument( + "--lora_path", + type=str, + default=None, + ) + parser.add_argument("--output_folder", type=str, default="./output_helios") + parser.add_argument("--use_default_loader", action="store_true") + parser.add_argument("--enable_compile", action="store_true") + parser.add_argument("--low_vram_mode", action="store_true") + parser.add_argument("--enable_parallelism", action="store_true") + + # === Generation parameters === + # environment + parser.add_argument("--debug_mode", action="store_true") + parser.add_argument( + "--sample_type", + type=str, + default="t2v", + choices=["t2v", "i2v", "v2v"], + ) + parser.add_argument( + "--weight_dtype", + type=str, + default="bf16", + choices=["bf16", "fp16", "fp32"], + help="Data type for model weights.", + ) + parser.add_argument("--seed", type=int, default=42, help="Seed for random number generator.") + # base + parser.add_argument("--height", type=int, default=384) + parser.add_argument("--width", type=int, default=640) + parser.add_argument("--num_frames", type=int, default=73) + parser.add_argument("--num_inference_steps", type=int, default=50) + parser.add_argument("--guidance_scale", type=float, default=5.0) + parser.add_argument("--use_dynamic_shifting", action="store_true") + parser.add_argument("--vae_decode_type", type=str, default="default", choices=["default", "once", "default_fast"]) + # cfg zero + parser.add_argument("--use_cfg_zero_star", 
action="store_true") + parser.add_argument("--use_zero_init", action="store_true") + parser.add_argument("--zero_steps", type=int, default=1) + # stage 1 + parser.add_argument("--latent_window_size", type=int, default=9) + # stage 2 + parser.add_argument("--is_enable_stage2", action="store_true") + parser.add_argument("--stage2_num_stages", type=int, default=3) + parser.add_argument("--stage2_timestep_shift", type=float, default=1.0) + parser.add_argument("--stage2_scheduler_gamma", type=float, default=1 / 3) + parser.add_argument("--stage2_stage_range", type=int, nargs="+", default=[0, 1 / 3, 2 / 3, 1]) + parser.add_argument("--stage2_num_inference_steps_list", type=int, nargs="+", default=[20, 20, 20]) + # stage 3 + parser.add_argument("--is_enable_stage3", action="store_true") + parser.add_argument("--is_skip_first_section", action="store_true") + parser.add_argument("--is_amplify_first_chunk", action="store_true") + + # === Prompts === + parser.add_argument("--use_interpolate_prompt", action="store_true") + parser.add_argument("--interpolation_steps", type=int, default=3) + parser.add_argument( + "--image_path", + type=str, + default=None, + ) + parser.add_argument( + "--video_path", + type=str, + default=None, + ) + parser.add_argument( + "--prompt", + type=str, + default="A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. 
Medium shot focusing on the train window and the rushing scenery beyond.", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + ) + parser.add_argument( + "--prompt_txt_path", + type=str, + default=None, + ) + parser.add_argument( + "--interactive_prompt_csv_path", + type=str, + default=None, + ) + parser.add_argument( + "--base_image_prompt_path", + type=str, + default=None, + ) + parser.add_argument( + "--image_prompt_csv_path", + type=str, + default=None, + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + + assert not (args.low_vram_mode and args.enable_compile), ( + "low_vram_mode and enable_compile cannot be used together." + ) + + if args.weight_dtype == "fp32": + args.weight_dtype = torch.float32 + elif args.weight_dtype == "fp16": + args.weight_dtype = torch.float16 + else: + args.weight_dtype = torch.bfloat16 + + os.makedirs(args.output_folder, exist_ok=True) + + if dist.is_available() and "RANK" in os.environ: + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + device = torch.device("cuda", rank % torch.cuda.device_count()) + world_size = dist.get_world_size() + torch.cuda.set_device(device) + assert world_size == 1 or not args.low_vram_mode, "low_vram_mode is only for single GPU." 
+ else: + rank = 0 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + world_size = 1 + + prompt = None + image_path = None + video_path = None + interpolate_time_list = None + if args.sample_type == "t2v" and args.prompt is None: + prompt = "An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film." + elif args.sample_type == "i2v" and (args.image_path is None and args.prompt is None): + image_path = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ) + prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + elif args.sample_type == "v2v" and (args.video_path is None and args.prompt is None): + video_path = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" + ) + prompt = "A robot standing on a mountain top. The sun is setting in the background." 
+ else: + image_path = args.image_path + video_path = args.video_path + if args.interactive_prompt_csv_path is not None and args.use_interpolate_prompt: + with open(args.prompt, "r") as f: + lines = [line.strip() for line in f.readlines() if line.strip()] + interpolate_time_list = [] + prompt = [] + for line in lines: + parts = line.split(",", 1) + if len(parts) == 2: + time_value = int(parts[0].strip()) + prompt_text = parts[1].strip().strip('"') + + interpolate_time_list.append(time_value) + prompt.append(prompt_text) + else: + prompt = args.prompt + + transformer = HeliosTransformer3DModel.from_pretrained( + args.transformer_path, + subfolder="transformer", + torch_dtype=args.weight_dtype, + use_default_loader=args.use_default_loader, + ) + transformer.set_attention_backend("_flash_3_hub") + + vae = AutoencoderKLWan.from_pretrained( + args.base_model_path, + subfolder="vae", + torch_dtype=torch.float32, + ) + if args.is_enable_stage2: + scheduler = HeliosUniPCScheduler( + shift=args.stage2_timestep_shift, + stages=args.stage2_num_stages, + stage_range=args.stage2_stage_range, + gamma=args.stage2_scheduler_gamma, + ) + pipe = HeliosPipeline.from_pretrained( + args.base_model_path, + transformer=transformer, + vae=vae, + scheduler=scheduler, + torch_dtype=args.weight_dtype, + ) + else: + pipe = HeliosPipeline.from_pretrained( + args.base_model_path, + transformer=transformer, + vae=vae, + torch_dtype=args.weight_dtype, + ) + + if args.lora_path is not None: + pipe.load_lora_weights(args.lora_path, adapter_name="default") + pipe.set_adapters(["default"], adapter_weights=[1.0]) + + if args.vae_decode_type == "once": + pipe.vae.enable_tiling() + + if args.enable_compile: + pipe.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=True) + + if args.low_vram_mode: + pipe.enable_group_offload( + onload_device=torch.device("cuda"), + offload_device=torch.device("cpu"), + # offload_type="leaf_level", + offload_type="block_level", + num_blocks_per_group=1, + 
use_stream=True, + record_stream=True, + ) + else: + pipe = pipe.to(device) + + if world_size > 1 and args.enable_parallelism: + transformer.set_attention_backend("flash") + pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=world_size)) + + if args.debug_mode: + + def parse_list_input(input_string): + input_string = input_string.strip("[]").strip() + if "," in input_string: + return [int(x.strip()) for x in input_string.split(",") if x.strip()] + else: + return [int(x.strip()) for x in input_string.split() if x.strip()] + + while True: + user_input = input("Please enter stage2_num_inference_steps_list (e.g., 10 20 30): ").strip() + + if user_input.lower() in ["q", "quit", "exit"]: + break + + try: + pyramid_steps = parse_list_input(user_input) + print(f"✅ Parsing successful: {pyramid_steps}") + except ValueError as e: + print(f"❌ Input format error: {e}") + print("Please re-enter...\n") + continue + + args.stage2_num_inference_steps_list = pyramid_steps + + with torch.no_grad(): + output = pipe( + prompt=prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, # 73 109 145 181 215 + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + generator=torch.Generator(device="cuda").manual_seed(args.seed), + vae_decode_type=args.vae_decode_type, + # stage 1 + history_sizes=[16, 2, 1], + latent_window_size=args.latent_window_size, + is_keep_x0=True, + use_dynamic_shifting=args.use_dynamic_shifting, + # stage 2 + is_enable_stage2=args.is_enable_stage2, + stage2_num_stages=args.stage2_num_stages, + stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, + scheduler_type="unipc", + # stage 3 + use_dmd=args.is_enable_stage3, + is_skip_first_section=args.is_skip_first_section, + is_amplify_first_chunk=args.is_amplify_first_chunk, + # cfg zero + use_cfg_zero_star=args.use_cfg_zero_star, + use_zero_init=args.use_zero_init, + 
zero_steps=args.zero_steps, + # i2v + image=load_image(image_path).resize((args.width, args.height)) if image_path is not None else None, + # t2v + video=load_video(video_path) if video_path is not None else None, + # interpolate_prompt + use_interpolate_prompt=args.use_interpolate_prompt, + interpolation_steps=args.interpolation_steps, + interpolate_time_list=interpolate_time_list, + ).frames[0] + + if not args.enable_parallelism or rank == 0: + file_count = len( + [f for f in os.listdir(args.output_folder) if os.path.isfile(os.path.join(args.output_folder, f))] + ) + output_path = os.path.join( + args.output_folder, f"{file_count:04d}_{args.sample_type}_{int(time.time())}.mp4" + ) + export_to_video(output, output_path, fps=24) + elif args.prompt_txt_path is not None: + with open(args.prompt_txt_path, "r") as f: + prompt_list = [line.strip() for line in f.readlines() if line.strip()] + if not args.enable_parallelism: + prompt_list_with_idx = [(i, prompt) for i, prompt in enumerate(prompt_list)] + prompt_list_with_idx = prompt_list_with_idx[rank::world_size] + else: + prompt_list_with_idx = [(i, prompt) for i, prompt in enumerate(prompt_list)] + + for idx, prompt in tqdm(prompt_list_with_idx, desc="Processing prompts"): + output_path = os.path.join(args.output_folder, f"{idx}.mp4") + if os.path.exists(output_path): + print("skipping!") + continue + + with torch.no_grad(): + try: + output = pipe( + prompt=prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, # 73 109 145 181 215 + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + generator=torch.Generator(device="cuda").manual_seed(args.seed), + vae_decode_type=args.vae_decode_type, + # stage 1 + history_sizes=[16, 2, 1], + latent_window_size=args.latent_window_size, + is_keep_x0=True, + use_dynamic_shifting=args.use_dynamic_shifting, + # stage 2 + is_enable_stage2=args.is_enable_stage2, + 
stage2_num_stages=args.stage2_num_stages, + stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, + scheduler_type="unipc", + # stage 3 + use_dmd=args.is_enable_stage3, + is_skip_first_section=args.is_skip_first_section, + is_amplify_first_chunk=args.is_amplify_first_chunk, + # cfg zero + use_cfg_zero_star=args.use_cfg_zero_star, + use_zero_init=args.use_zero_init, + zero_steps=args.zero_steps, + # i2v + image=load_image(image_path).resize((args.width, args.height)) + if image_path is not None + else None, + # t2v + video=load_video(video_path) if video_path is not None else None, + # interpolate_prompt + use_interpolate_prompt=args.use_interpolate_prompt, + interpolation_steps=args.interpolation_steps, + interpolate_time_list=interpolate_time_list, + ).frames[0] + except Exception: + continue + if not args.enable_parallelism or rank == 0: + export_to_video(output, output_path, fps=24) + elif args.interactive_prompt_csv_path is not None: + df = pd.read_csv(args.interactive_prompt_csv_path) + + df = df.sort_values(by=["id", "prompt_index"]) + all_video_ids = df["id"].unique() + + if not args.enable_parallelism: + my_video_ids = all_video_ids[rank::world_size] + else: + my_video_ids = all_video_ids + + for video_id in tqdm(my_video_ids, desc="Processing prompts"): + output_path = os.path.join(args.output_folder, f"{video_id}.mp4") + + if os.path.exists(output_path): + print(f"skipping {output_path}!") + continue + + group_df = df[df["id"] == video_id] + + if "refined_prompt" in df.columns: + prompt_list = group_df["refined_prompt"].fillna(group_df["prompt"]).tolist() + else: + prompt_list = group_df["prompt"].tolist() + interpolate_time_list = [7] * len(prompt_list) + + with torch.no_grad(): + try: + output = pipe( + prompt=prompt_list, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, # 73 109 145 181 215 + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, 
+ generator=torch.Generator(device="cuda").manual_seed(args.seed), + vae_decode_type=args.vae_decode_type, + # stage 1 + history_sizes=[16, 2, 1], + latent_window_size=args.latent_window_size, + is_keep_x0=True, + use_dynamic_shifting=args.use_dynamic_shifting, + # stage 2 + is_enable_stage2=args.is_enable_stage2, + stage2_num_stages=args.stage2_num_stages, + stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, + scheduler_type="unipc", + # stage 3 + use_dmd=args.is_enable_stage3, + is_skip_first_section=args.is_skip_first_section, + is_amplify_first_chunk=args.is_amplify_first_chunk, + # cfg zero + use_cfg_zero_star=args.use_cfg_zero_star, + use_zero_init=args.use_zero_init, + zero_steps=args.zero_steps, + # i2v + image=load_image(image_path).resize((args.width, args.height)) + if image_path is not None + else None, + # t2v + video=load_video(video_path) if video_path is not None else None, + # interpolate_prompt + use_interpolate_prompt=args.use_interpolate_prompt, + interpolation_steps=args.interpolation_steps, + interpolate_time_list=interpolate_time_list, + ).frames[0] + except Exception: + continue + if not args.enable_parallelism or rank == 0: + export_to_video(output, output_path, fps=24) + elif args.image_prompt_csv_path is not None: + df = pd.read_csv(args.image_prompt_csv_path) + if not args.enable_parallelism: + df = df.iloc[rank::world_size] + + for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing prompts"): + # output_path = os.path.join(args.output_folder, f"{idx}.mp4") + output_path = os.path.join(args.output_folder, f"{row['id']}.mp4") + if os.path.exists(output_path): + print("skipping!") + continue + + prompt = row.get("refined_prompt") or row["prompt"] + image_path = os.path.join(args.base_image_prompt_path, row["image_name"]) + + with torch.no_grad(): + try: + output = pipe( + prompt=prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, # 73 109 
145 181 215 + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + generator=torch.Generator(device="cuda").manual_seed(args.seed), + vae_decode_type=args.vae_decode_type, + # stage 1 + history_sizes=[16, 2, 1], + latent_window_size=args.latent_window_size, + is_keep_x0=True, + use_dynamic_shifting=args.use_dynamic_shifting, + # stage 2 + is_enable_stage2=args.is_enable_stage2, + stage2_num_stages=args.stage2_num_stages, + stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, + scheduler_type="unipc", + # stage 3 + use_dmd=args.is_enable_stage3, + is_skip_first_section=args.is_skip_first_section, + is_amplify_first_chunk=args.is_amplify_first_chunk, + # cfg zero + use_cfg_zero_star=args.use_cfg_zero_star, + use_zero_init=args.use_zero_init, + zero_steps=args.zero_steps, + # i2v + image=load_image(image_path).resize((args.width, args.height)) + if image_path is not None + else None, + # t2v + video=load_video(video_path) if video_path is not None else None, + # interpolate_prompt + use_interpolate_prompt=args.use_interpolate_prompt, + interpolation_steps=args.interpolation_steps, + interpolate_time_list=interpolate_time_list, + ).frames[0] + except Exception: + continue + if not args.enable_parallelism or rank == 0: + export_to_video(output, output_path, fps=24) + else: + with torch.no_grad(): + output = pipe( + prompt=prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, # 73 109 145 181 215 + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + generator=torch.Generator(device="cuda").manual_seed(args.seed), + vae_decode_type=args.vae_decode_type, + # stage 1 + history_sizes=[16, 2, 1], + latent_window_size=args.latent_window_size, + is_keep_x0=True, + use_dynamic_shifting=args.use_dynamic_shifting, + # stage 2 + is_enable_stage2=args.is_enable_stage2, + stage2_num_stages=args.stage2_num_stages, + 
stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, + scheduler_type="unipc", + # stage 3 + use_dmd=args.is_enable_stage3, + is_skip_first_section=args.is_skip_first_section, + is_amplify_first_chunk=args.is_amplify_first_chunk, + # cfg zero + use_cfg_zero_star=args.use_cfg_zero_star, + use_zero_init=args.use_zero_init, + zero_steps=args.zero_steps, + # i2v + image=load_image(image_path).resize((args.width, args.height)) if image_path is not None else None, + # t2v + video=load_video(video_path) if video_path is not None else None, + # interpolate_prompt + use_interpolate_prompt=args.use_interpolate_prompt, + interpolation_steps=args.interpolation_steps, + interpolate_time_list=interpolate_time_list, + ).frames[0] + + if not args.enable_parallelism or rank == 0: + file_count = len( + [f for f in os.listdir(args.output_folder) if os.path.isfile(os.path.join(args.output_folder, f))] + ) + output_path = os.path.join( + args.output_folder, f"{file_count:04d}_{args.sample_type}_{int(time.time())}.mp4" + ) + export_to_video(output, output_path, fps=24) + + print(f"Max memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB") + + +if __name__ == "__main__": + main() diff --git a/0_temp_helios_test/requirements.txt b/0_temp_helios_test/requirements.txt new file mode 100644 index 000000000000..b1f777548ab3 --- /dev/null +++ b/0_temp_helios_test/requirements.txt @@ -0,0 +1,35 @@ +torch==2.7.1 +torchvision==0.22.1 +torchaudio==2.7.1 +triton==3.3.1 +# diffusers==0.36.0 +# transformers==4.57.6 +# sentence-transformers==5.2.3 +git+https://github.com/SHYuanBest/diffusers.git@test +git+https://github.com/huggingface/transformers.git +git+https://github.com/huggingface/sentence-transformers.git +accelerate==1.12.0 +deepspeed==0.18.4 +peft==0.18.1 +huggingface-hub==1.4.1 +zstandard==0.25.0 +wandb==0.23.0 +video-reader-rs==0.4.1 +opencv-python +gradio +spaces +moviepy +imageio-ffmpeg +ftfy +Jinja2 +einops +nvitop +packaging +ninja +omegaconf +mpi4py 
+hf-doc-builder +torchdata +kernels +loguru +tf_keras \ No newline at end of file diff --git a/0_temp_helios_test/stage-1_i2v.sh b/0_temp_helios_test/stage-1_i2v.sh new file mode 100644 index 000000000000..36da9461fea9 --- /dev/null +++ b/0_temp_helios_test/stage-1_i2v.sh @@ -0,0 +1,14 @@ +CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ + --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base" \ + --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base" \ + --sample_type "i2v" \ + --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" \ + --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." 
\ + --use_dynamic_shifting \ + --output_folder "./output_helios/stage-1" + + # --use_default_loader \ + # --enable_compile \ + # --use_cfg_zero_star \ + # --use_zero_init \ + # --zero_steps 1 \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-1_t2v.sh b/0_temp_helios_test/stage-1_t2v.sh new file mode 100644 index 000000000000..8b2fea961533 --- /dev/null +++ b/0_temp_helios_test/stage-1_t2v.sh @@ -0,0 +1,14 @@ +CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ + --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base" \ + --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base" \ + --sample_type "t2v" \ + --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." 
\ + --use_dynamic_shifting \ + --output_folder "./output_helios/stage-1" + + + # --use_default_loader \ + # --enable_compile \ + # --use_cfg_zero_star \ + # --use_zero_init \ + # --zero_steps 1 \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-1_v2v.sh b/0_temp_helios_test/stage-1_v2v.sh new file mode 100644 index 000000000000..349c42d110da --- /dev/null +++ b/0_temp_helios_test/stage-1_v2v.sh @@ -0,0 +1,14 @@ +CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ + --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base" \ + --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base" \ + --sample_type "v2v" \ + --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" \ + --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ + --use_dynamic_shifting \ + --output_folder "./output_helios/stage-1" + + # --use_default_loader \ + # --enable_compile \ + # --use_cfg_zero_star \ + # --use_zero_init \ + # --zero_steps 1 \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-2_i2v.sh b/0_temp_helios_test/stage-2_i2v.sh new file mode 100644 index 000000000000..56b3ad46d1a2 --- /dev/null +++ b/0_temp_helios_test/stage-2_i2v.sh @@ -0,0 +1,15 @@ +CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ + --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Mid" \ + --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Mid" \ + --sample_type "i2v" \ + --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" \ + --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." 
\ + --is_enable_stage2 \ + --use_dynamic_shifting \ + --use_cfg_zero_star \ + --use_zero_init \ + --zero_steps 1 \ + --output_folder "./output_helios/stage-2" + + # --use_default_loader \ + # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-2_t2v.sh b/0_temp_helios_test/stage-2_t2v.sh new file mode 100644 index 000000000000..078b7ebaf101 --- /dev/null +++ b/0_temp_helios_test/stage-2_t2v.sh @@ -0,0 +1,15 @@ +CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ + --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Mid" \ + --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Mid" \ + --sample_type "t2v" \ + --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." 
\ + --is_enable_stage2 \ + --use_dynamic_shifting \ + --use_cfg_zero_star \ + --use_zero_init \ + --zero_steps 1 \ + --output_folder "./output_helios/stage-2" + + + # --use_default_loader \ + # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-2_v2v.sh b/0_temp_helios_test/stage-2_v2v.sh new file mode 100644 index 000000000000..9158e9a7f68b --- /dev/null +++ b/0_temp_helios_test/stage-2_v2v.sh @@ -0,0 +1,15 @@ +CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ + --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Mid" \ + --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Mid" \ + --sample_type "v2v" \ + --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" \ + --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ + --is_enable_stage2 \ + --use_dynamic_shifting \ + --use_cfg_zero_star \ + --use_zero_init \ + --zero_steps 1 \ + --output_folder "./output_helios/stage-2" + + # --use_default_loader \ + # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-3_i2v.sh b/0_temp_helios_test/stage-3_i2v.sh new file mode 100644 index 000000000000..340c21ab639d --- /dev/null +++ b/0_temp_helios_test/stage-3_i2v.sh @@ -0,0 +1,17 @@ +CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ + --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Distilled" \ + --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Distilled" \ + --sample_type "i2v" \ + --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" \ + --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." 
\ + --num_frames 240 \ + --guidance_scale 1.0 \ + --is_enable_stage2 \ + --stage2_num_inference_steps_list 2 2 2 \ + --is_enable_stage3 \ + --use_dynamic_shifting \ + --is_amplify_first_chunk \ + --output_folder "./output_helios/stage-3" + + # --use_default_loader \ + # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-3_t2v.sh b/0_temp_helios_test/stage-3_t2v.sh new file mode 100644 index 000000000000..5240f3cbeff6 --- /dev/null +++ b/0_temp_helios_test/stage-3_t2v.sh @@ -0,0 +1,16 @@ +CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ + --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Distilled" \ + --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Distilled" \ + --sample_type "t2v" \ + --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." 
\ + --num_frames 240 \ + --guidance_scale 1.0 \ + --is_enable_stage2 \ + --stage2_num_inference_steps_list 2 2 2 \ + --is_enable_stage3 \ + --use_dynamic_shifting \ + --is_amplify_first_chunk \ + --output_folder "./output_helios/stage-3" + + # --use_default_loader \ + # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-3_v2v.sh b/0_temp_helios_test/stage-3_v2v.sh new file mode 100644 index 000000000000..1be947dd0f00 --- /dev/null +++ b/0_temp_helios_test/stage-3_v2v.sh @@ -0,0 +1,17 @@ +CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ + --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Distilled" \ + --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Distilled" \ + --sample_type "v2v" \ + --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" \ + --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ + --num_frames 240 \ + --guidance_scale 1.0 \ + --is_enable_stage2 \ + --stage2_num_inference_steps_list 2 2 2 \ + --is_enable_stage3 \ + --use_dynamic_shifting \ + --is_amplify_first_chunk \ + --output_folder "./output_helios/stage-3" + + # --use_default_loader \ + # --enable_compile \ \ No newline at end of file diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 098660ec3f39..113d95def348 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -194,6 +194,8 @@ title: Model accelerators and hardware - isExpanded: false sections: + - local: using-diffusers/helios + title: Helios - local: using-diffusers/consisid title: ConsisID - local: using-diffusers/sdxl @@ -350,6 +352,8 @@ title: FluxTransformer2DModel - local: api/models/glm_image_transformer2d title: GlmImageTransformer2DModel + - local: api/models/helios_transformer3d + title: HeliosTransformer3DModel - local: api/models/hidream_image_transformer title: HiDreamImageTransformer2DModel - local: api/models/hunyuan_transformer2d @@ -625,7 +629,6 @@ title: Image-to-image - local: api/pipelines/stable_diffusion/inpaint title: Inpainting - - local: api/pipelines/stable_diffusion/latent_upscale title: Latent upscaler - local: api/pipelines/stable_diffusion/ldm3d_diffusion @@ -674,6 +677,8 @@ title: ConsisID - local: api/pipelines/framepack title: Framepack + - local: api/pipelines/helios + title: Helios - local: api/pipelines/hunyuan_video title: HunyuanVideo - local: api/pipelines/hunyuan_video15 @@ -745,6 +750,8 @@ title: FlowMatchEulerDiscreteScheduler - local: api/schedulers/flow_match_heun_discrete title: FlowMatchHeunDiscreteScheduler + - local: api/schedulers/helios_unipc + title: HeliosUniPCScheduler - local: api/schedulers/heun title: HeunDiscreteScheduler - local: api/schedulers/ipndm diff --git a/docs/source/en/api/models/helios_transformer3d.md b/docs/source/en/api/models/helios_transformer3d.md new file mode 100644 index 000000000000..75f1c536aa3e --- 
/dev/null +++ b/docs/source/en/api/models/helios_transformer3d.md @@ -0,0 +1,35 @@ + + +# HeliosTransformer3DModel + +A 14B Real-Time Autoregressive Diffusion Transformer model (supports T2V, I2V and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: 14B Real-Time Long Video Generation Model can be Cheaper, Faster but Keep Stronger than 1.3B ones](https://huggingface.co/papers/) by Peking University, ByteDance, et al. + +The model can be loaded with the following code snippet. + +```python +from diffusers import HeliosTransformer3DModel + +# Best Quality +transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="transformer", torch_dtype=torch.bfloat16) +# Intermediate Weight +transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="transformer", torch_dtype=torch.bfloat16) +# Best Efficiency +transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## HeliosTransformer3DModel + +[[autodoc]] HeliosTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md new file mode 100644 index 000000000000..e59c38211663 --- /dev/null +++ b/docs/source/en/api/pipelines/helios.md @@ -0,0 +1,463 @@ + + +
+
+ + LoRA + +
+
+ +# Helios + +[Helios: 14B Real-Time Long Video Generation Model can be Cheaper, Faster but Keep Stronger than 1.3B ones](https://huggingface.co/papers/) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Xinwei Huang, Xiao Yang, Li Yuan. + +* We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality. We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios). 
+ +The following Helios models are supported in Diffusers: + +- [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): Best Quality +- [Helios-Mid](https://huggingface.co/BestWishYsh/Helios-Mid): Intermediate Weight +- [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): Best Efficiency + +> [!TIP] +> Click on the Helios models in the right sidebar for more examples of video generation. + +### Optimizing Memory and Inference Speed + +The example below demonstrates how to generate a video from text optimized for memory or inference speed. + + + + +Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques. + +The Helios model below requires ~19GB of VRAM. + +```py +import torch +from diffusers import AutoModel, HeliosPipeline +from diffusers.hooks.group_offloading import apply_group_offloading +from diffusers.utils import export_to_video + +vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32) + +# group-offloading +pipeline = HeliosPipeline.from_pretrained( + "BestWishYsh/Helios-Base", + vae=vae, + torch_dtype=torch.bfloat16 +) +pipeline.enable_group_offload( + onload_device=torch.device("cuda"), + offload_device=torch.device("cpu"), + offload_type="block_level", + num_blocks_per_group=1, + use_stream=True, + record_stream=True, +) + +prompt = """ +A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various +elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. +The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but +emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and +exploration. 
Medium shot focusing on the train window and the rushing scenery beyond. +""" +negative_prompt = """ +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=99, + num_inference_steps=50, + use_dynamic_shifting=True, +).frames[0] +export_to_video(output, "output.mp4", fps=24) +``` + + + + +[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Attention Backends](../../optimization/attention_backends) such as FlashAttention and SageAttention can significantly increase speed by optimizing the computation of the attention mechanism. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs. 
+ +```py +import torch +from diffusers import AutoModel, HeliosPipeline +from diffusers.utils import export_to_video + +vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32) + +pipeline = HeliosPipeline.from_pretrained( + "BestWishYsh/Helios-Base", + vae=vae, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +# attention backend +# pipeline.transformer.set_attention_backend("flash") +pipeline.transformer.set_attention_backend("_flash_3_hub") # For Hopper GPUs + +# torch.compile +torch.backends.cudnn.benchmark = True +pipeline.text_encoder.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +pipeline.vae.compile(mode="max-autotune-no-cudagraphs", dynamic=False) +pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=False) + +prompt = """ +A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various +elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. +The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but +emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and +exploration. Medium shot focusing on the train window and the rushing scenery beyond. 
+""" +negative_prompt = """ +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=99, + num_inference_steps=50, + use_dynamic_shifting=True, +).frames[0] +export_to_video(output, "output.mp4", fps=24) +``` + + + + + +### Generation with Helios-Base + +The example below demonstrates how to use Helios-Base to generate video based on text, image or video. + + + + +```python +import torch +from diffusers import AutoModel, HeliosPipeline +from diffusers.utils import export_to_video, load_video, load_image + +vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32) + +pipeline = HeliosPipeline.from_pretrained( + "BestWishYsh/Helios-Base", + vae=vae, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +negative_prompt = """ +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards +""" + +# For Text-to-Video +prompt = """ +A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various +elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. 
+The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but +emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and +exploration. Medium shot focusing on the train window and the rushing scenery beyond. +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=99, + num_inference_steps=50, + use_dynamic_shifting=True, +).frames[0] +export_to_video(output, "output_t2v.mp4", fps=24) + +# For Image-to-Video +prompt = """ +A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, +illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, +casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes +apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and +relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and +respect for nature’s might. +""" +image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + image=load_image(image_path).resize((640, 384)), + num_frames=99, + num_inference_steps=50, + use_dynamic_shifting=True, +).frames[0] +export_to_video(output, "output_i2v.mp4", fps=24) + +# For Video-to-Video +prompt = """ +A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees +under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, +emphasizing its dynamic movement. 
The road curves gently, with a guardrail visible on one side, adding depth to +the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. +A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery. +""" +video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + video=load_video(video_path), + num_frames=99, + num_inference_steps=50, + use_dynamic_shifting=True, +).frames[0] +export_to_video(output, "output_v2v.mp4", fps=24) +``` + + + + + +### Generation with Helios-Mid + +The example below demonstrates how to use Helios-Mid to generate video based on text, image or video. + + + + +```python +import torch +from diffusers import AutoModel, HeliosPipeline +from diffusers.utils import export_to_video, load_video, load_image + +vae = AutoModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="vae", torch_dtype=torch.float32) + +pipeline = HeliosPipeline.from_pretrained( + "BestWishYsh/Helios-Mid", + vae=vae, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +negative_prompt = """ +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards +""" + +# For Text-to-Video +prompt = """ +A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various +elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. 
+The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but +emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and +exploration. Medium shot focusing on the train window and the rushing scenery beyond. +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=99, + use_dynamic_shifting=True, + is_enable_stage2=True, + stage2_num_stages=3, + stage2_num_inference_steps_list=[20, 20, 20], + use_cfg_zero_star=True, + use_zero_init=True, + zero_steps=1, +).frames[0] +export_to_video(output, "output_t2v.mp4", fps=24) + +# For Image-to-Video +prompt = """ +A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, +illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, +casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes +apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and +relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and +respect for nature’s might. 
+""" +image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + image=load_image(image_path).resize((640, 384)), + num_frames=99, + use_dynamic_shifting=True, + is_enable_stage2=True, + stage2_num_stages=3, + stage2_num_inference_steps_list=[20, 20, 20], + use_cfg_zero_star=True, + use_zero_init=True, + zero_steps=1, +).frames[0] +export_to_video(output, "output_i2v.mp4", fps=24) + +# For Video-to-Video +prompt = """ +A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees +under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, +emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to +the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. +A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery. +""" +video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + video=load_video(video_path), + num_frames=99, + use_dynamic_shifting=True, + is_enable_stage2=True, + stage2_num_stages=3, + stage2_num_inference_steps_list=[20, 20, 20], + use_cfg_zero_star=True, + use_zero_init=True, + zero_steps=1, +).frames[0] +export_to_video(output, "output_v2v.mp4", fps=24) +``` + + + + + +### Generation with Helios-Distilled + +The example below demonstrates how to use Helios-Distilled to generate video based on text, image or video. 
+ + + + +```python +import torch +from diffusers import AutoModel, HeliosPipeline +from diffusers.utils import export_to_video, load_video, load_image + +vae = AutoModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="vae", torch_dtype=torch.float32) + +pipeline = HeliosPipeline.from_pretrained( + "BestWishYsh/Helios-Distilled", + vae=vae, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +negative_prompt = """ +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards +""" + +# For Text-to-Video +prompt = """ +A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various +elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. +The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but +emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and +exploration. Medium shot focusing on the train window and the rushing scenery beyond. +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=99, + use_dynamic_shifting=True, + is_enable_stage2=True, + stage2_num_stages=3, + stage2_num_inference_steps_list=[2, 2, 2], + use_dmd=True, + guidance_scale=1.0, + is_amplify_first_chunk=True, +).frames[0] +export_to_video(output, "output_t2v.mp4", fps=24) + +# For Image-to-Video +prompt = """ +A towering emerald wave surges forward, its crest curling with raw power and energy. 
Sunlight glints off the translucent water, +illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, +casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes +apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and +relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and +respect for nature’s might. +""" +image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + image=load_image(image_path).resize((640, 384)), + num_frames=99, + use_dynamic_shifting=True, + is_enable_stage2=True, + stage2_num_stages=3, + stage2_num_inference_steps_list=[2, 2, 2], + use_dmd=True, + guidance_scale=1.0, + is_amplify_first_chunk=True, +).frames[0] +export_to_video(output, "output_i2v.mp4", fps=24) + +# For Video-to-Video +prompt = """ +A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees +under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, +emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to +the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. +A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery. 
+""" +video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + video=load_video(video_path), + num_frames=99, + use_dynamic_shifting=True, + is_enable_stage2=True, + stage2_num_stages=3, + stage2_num_inference_steps_list=[2, 2, 2], + use_dmd=True, + guidance_scale=1.0, + is_amplify_first_chunk=True, +).frames[0] +export_to_video(output, "output_v2v.mp4", fps=24) +``` + + + + + +## HeliosPipeline + +[[autodoc]] HeliosPipeline + + - all + - __call__ + +## HeliosPipelineOutput + +[[autodoc]] pipelines.Helios.pipeline_output.HeliosPipelineOutput diff --git a/docs/source/en/api/schedulers/helios_unipc.md b/docs/source/en/api/schedulers/helios_unipc.md new file mode 100644 index 000000000000..97691b42479e --- /dev/null +++ b/docs/source/en/api/schedulers/helios_unipc.md @@ -0,0 +1,20 @@ + + +# HeliosUniPCScheduler + +`HeliosUniPCScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers). + +## HeliosUniPCScheduler +[[autodoc]] HeliosUniPCScheduler + +scheduling_helios_unipc diff --git a/docs/source/en/using-diffusers/helios.md b/docs/source/en/using-diffusers/helios.md new file mode 100644 index 000000000000..93b276c29d2d --- /dev/null +++ b/docs/source/en/using-diffusers/helios.md @@ -0,0 +1,67 @@ + +# Helios + +[Helios](https://github.com/PKU-YuanGroup/Helios) is the first 14B video generation model that runs at 19.5 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching the quality of a strong baseline, natively integrating T2V, I2V, and V2V tasks within a unified architecture. The main features of Helios are: + +- Without commonly used anti-drifting strategies (\eg, self-forcing, error-banks, keyframe sampling, or inverted sampling), \ours generates minute-scale videos with high quality and strong coherence. 
+- Without standard acceleration techniques (e.g., KV-cache, causal masking, sparse/linear attention, TinyVAE, progressive noise schedules, hidden-state caching, or quantization), Helios achieves 19.5 FPS in end-to-end inference for a 14B video generation model on a single H100 GPU. +- Introducing optimizations that improve both training and inference throughput while reducing memory consumption. These changes enable training a 14B video generation model without parallelism or sharding infrastructure, with batch sizes comparable to image models. + +This guide will walk you through using Helios for common use cases. + +## Load Model Checkpoints + +Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method. + +```python +# !pip install Helios_eva_clip insightface facexlib +import torch +from diffusers import HeliosPipeline +from huggingface_hub import snapshot_download + +# For Best Quality +snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base") +pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Intermediate Weight +snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid") +pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# For Best Efficiency +snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled") +pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16) +pipe.to("cuda") +``` + +## Text-to-Video Showcases + + 
+ +## Image-to-Video Showcases + + +
+ +## Interactive-Video Showcases + + +
+ +## Resources + +Learn more about Helios with the following resources. +- A [video](https://www.youtube.com/watch?v=) demonstrating Helios's main features. +- The research paper, [Helios: 14B Real-Time Long Video Generation Model can be Cheaper, Faster but Keep Stronger than 1.3B ones](https://huggingface.co/papers/) for more details. diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index 337d010fc74d..ab9eaf6ec7fb 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -132,6 +132,8 @@ sections: - local: using-diffusers/consisid title: ConsisID + - local: using-diffusers/helios + title: Helios - title: Resources isExpanded: false diff --git a/docs/source/zh/community_projects.md b/docs/source/zh/community_projects.md index 0440142452f1..ffa45f1e9bb0 100644 --- a/docs/source/zh/community_projects.md +++ b/docs/source/zh/community_projects.md @@ -26,6 +26,14 @@ http://www.apache.org/licenses/LICENSE-2.0 项目名称 描述 + + helios + Helios:比1.3B更低开销、更快且更强的14B的实时长视频生成模型 + + + consisid + ConsisID:零样本身份保持的文本到视频生成模型 + dream-textures Stable Diffusion内置到Blender diff --git a/docs/source/zh/using-diffusers/helios.md b/docs/source/zh/using-diffusers/helios.md new file mode 100644 index 000000000000..17447ed809eb --- /dev/null +++ b/docs/source/zh/using-diffusers/helios.md @@ -0,0 +1,68 @@ + +# Helios + +[Helios](https://github.com/PKU-YuanGroup/Helios) 是首个能够在单张 NVIDIA H100 GPU 上以 19.5 FPS 运行的 14B 视频生成模型。它在支持分钟级视频生成的同时,拥有媲美强大基线模型的生成质量,并在统一架构下原生集成了文生视频(T2V)、图生视频(I2V)和视频生视频(V2V)任务。Helios 的主要特性包括: + +- 无需常用的防漂移策略(例如:自强制/self-forcing、误差库/error-banks、关键帧采样或逆采样),我们的模型即可生成高质量且高度连贯的分钟级视频。 +- 无需标准的加速技术(例如:KV 缓存、因果掩码、稀疏/线性注意力机制、TinyVAE、渐进式噪声调度、隐藏状态缓存或量化),作为一款 14B 规模的视频生成模型,我们在单张 H100 GPU 上的端到端推理速度便达到了 19.5 FPS。 +- 引入了多项优化方案,在降低显存消耗的同时,显著提升了训练与推理的吞吐量。这些改进使得我们无需借助并行或分片(sharding)等基础设施,即可使用与图像模型相当的批大小(batch sizes)来训练 14B 的视频生成模型。 + +本指南将引导您完成 Helios 在不同场景下的使用。 + +## Load Model Checkpoints + +模型权重可以存储在Hub上或本地的单独子文件夹中,在这种情况下,您应该使用 
[`~DiffusionPipeline.from_pretrained`] 方法。 + +```python +# !pip install Helios_eva_clip insightface facexlib +import torch +from diffusers import HeliosPipeline +from huggingface_hub import snapshot_download + +# For Best Quality +snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base") +pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Intermediate Weight +snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid") +pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# For Best Efficiency +snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled") +pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16) +pipe.to("cuda") +``` + +## Text-to-Video Showcases + + +
+ +## Image-to-Video Showcases + + +
+ +## Interactive-Video Showcases + + +
+ +## Resources + +通过以下资源了解有关 Helios 的更多信息: + +- 一段 [视频](https://www.youtube.com/watch?v=) 演示了 Helios 的主要功能; +- 有关更多详细信息,请参阅研究论文 [Helios: 14B Real-Time Long Video Generation Model can be Cheaper, Faster but Keep Stronger than 1.3B ones](https://huggingface.co/papers/)。 diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1fc0914fe09e..a8c5587bc531 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -227,6 +227,7 @@ "FluxMultiControlNetModel", "FluxTransformer2DModel", "GlmImageTransformer2DModel", + "HeliosTransformer3DModel", "HiDreamImageTransformer2DModel", "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", @@ -359,6 +360,7 @@ "FlowMatchEulerDiscreteScheduler", "FlowMatchHeunDiscreteScheduler", "FlowMatchLCMScheduler", + "HeliosUniPCScheduler", "HeunDiscreteScheduler", "IPNDMScheduler", "KarrasVeScheduler", @@ -515,6 +517,7 @@ "FluxPipeline", "FluxPriorReduxPipeline", "GlmImagePipeline", + "HeliosPipeline", "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -994,6 +997,7 @@ FluxMultiControlNetModel, FluxTransformer2DModel, GlmImageTransformer2DModel, + HeliosTransformer3DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, @@ -1122,6 +1126,7 @@ FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler, FlowMatchLCMScheduler, + HeliosUniPCScheduler, HeunDiscreteScheduler, IPNDMScheduler, KarrasVeScheduler, @@ -1257,6 +1262,7 @@ FluxPipeline, FluxPriorReduxPipeline, GlmImagePipeline, + HeliosPipeline, HiDreamImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index bdd4dbbcd4b5..ed0d2a07336f 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -78,6 +78,7 @@ def text_encoder_attn_modules(text_encoder): "SanaLoraLoaderMixin", "Lumina2LoraLoaderMixin", "WanLoraLoaderMixin", + "HeliosLoraLoaderMixin", 
class HeliosLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`HeliosTransformer3DModel`]. Specific to [`HeliosPipeline`].
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
        **kwargs,
    ):
        r"""
        Fetch a LoRA state dict (and optionally its metadata) and normalize its keys to
        the diffusers layout.

        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )
        # Helios checkpoints reuse the Wan LoRA key layouts, so the Wan converters apply here too.
        if any(k.startswith("diffusion_model.") for k in state_dict):
            state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
        elif any(k.startswith("lora_unet_") for k in state_dict):
            state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        out = (state_dict, metadata) if return_lora_metadata else state_dict
        return out

    @classmethod
    def _maybe_expand_t2v_lora_for_i2v(
        cls,
        transformer: torch.nn.Module,
        state_dict,
    ):
        """Pad a T2V-only LoRA with zero `add_k_proj`/`add_v_proj` weights so it can be
        loaded into an I2V transformer (which has the extra image cross-attn projections).

        Zero-filled adapters are a no-op at inference, so this is behavior-preserving.
        """
        # T2V transformers (image_dim is None) need no expansion.
        if transformer.config.image_dim is None:
            return state_dict

        target_device = transformer.device

        if any(k.startswith("transformer.blocks.") for k in state_dict):
            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
            is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
            has_bias = any(".lora_B.bias" in k for k in state_dict)

            # Already an I2V LoRA — nothing to expand.
            if is_i2v_lora:
                return state_dict

            for i in range(num_blocks):
                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
                    # These keys should exist if the block `i` was part of the T2V LoRA.
                    ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
                    ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"

                    if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
                        continue

                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
                    )
                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
                    )

                    # If the original LoRA had biases (indicated by has_bias)
                    # AND the specific reference bias key exists for this block.
                    ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
                    if has_bias and ref_key_lora_B_bias in state_dict:
                        ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
                        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
                            ref_lora_B_bias_tensor,
                            device=target_device,
                        )

        return state_dict

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
        adapter_name: str | None = None,
        hotswap: bool = False,
        **kwargs,
    ):
        """
        Load LoRA weights into the pipeline's transformer (or `transformer_2` when
        `load_into_transformer_2=True` is passed and the component exists).

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        kwargs["return_lora_metadata"] = True
        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        # convert T2V LoRA to I2V LoRA (when loaded to Helios I2V) by adding zeros for the additional (missing) _img layers
        state_dict = self._maybe_expand_t2v_lora_for_i2v(
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            state_dict=state_dict,
        )
        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
        if load_into_transformer_2:
            if not hasattr(self, "transformer_2"):
                # Fixed copy-paste residue: message previously ran three strings together
                # without separators and referenced Wan2.1 instead of Helios.
                raise AttributeError(
                    f"'{type(self).__name__}' object has no attribute `transformer_2`. "
                    "Note that Helios models do not have a transformer_2 component. "
                    "Ensure the model has a transformer_2 component before setting `load_into_transformer_2=True`."
                )
            self.load_lora_into_transformer(
                state_dict,
                transformer=self.transformer_2,
                adapter_name=adapter_name,
                metadata=metadata,
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
                hotswap=hotswap,
            )
        else:
            self.load_lora_into_transformer(
                state_dict,
                transformer=getattr(self, self.transformer_name)
                if not hasattr(self, "transformer")
                else self.transformer,
                adapter_name=adapter_name,
                metadata=metadata,
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
                hotswap=hotswap,
            )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HeliosTransformer3DModel
    def load_lora_into_transformer(
        cls,
        state_dict,
        transformer,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: str | os.PathLike,
        transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        transformer_lora_adapter_metadata: dict | None = None,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
        """
        lora_layers = {}
        lora_metadata = {}

        if transformer_lora_layers:
            lora_layers[cls.transformer_name] = transformer_lora_layers
            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

        if not lora_layers:
            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

        cls._save_lora_weights(
            save_directory=save_directory,
            lora_layers=lora_layers,
            lora_metadata=lora_metadata,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: list[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: list[str] | None = None,
        **kwargs,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
        """
        super().unfuse_lora(components=components, **kwargs)
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index d9d1b27a1e40..45157ee91808 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -28,6 +28,7 @@ from .transformer_flux import FluxTransformer2DModel from .transformer_flux2 import Flux2Transformer2DModel from .transformer_glm_image import GlmImageTransformer2DModel + from .transformer_helios import HeliosTransformer3DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py new file mode 100644 index 000000000000..fb1e0f15cf87 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -0,0 +1,943 @@ +# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import math
from functools import lru_cache
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def pad_for_3d_conv(x, kernel_size):
    """Right-pad a (B, C, T, H, W) tensor (replicate mode) so T/H/W are divisible by `kernel_size`."""
    b, c, t, h, w = x.shape
    pt, ph, pw = kernel_size
    pad_t = (pt - (t % pt)) % pt
    pad_h = (ph - (h % ph)) % ph
    pad_w = (pw - (w % pw)) % pw
    return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")


def center_down_sample_3d(x, kernel_size):
    """Downsample a (B, C, T, H, W) tensor by average pooling with stride == kernel_size."""
    return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)


def apply_rotary_emb_transposed(
    hidden_states: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Apply interleaved rotary position embeddings.

    `freqs_cis` holds [cos | sin] halves concatenated on the last dim (see
    `HeliosRotaryPosEmbed.forward`, which concatenates cos then sin blocks).
    """
    x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
    cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
    out = torch.empty_like(hidden_states)
    out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]
    out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]
    return out.type_as(hidden_states)


def _get_qkv_projections(attn: "HeliosAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
    """Compute Q/K/V, using fused projections when the module has been fused."""
    # encoder_hidden_states is only passed for cross-attention
    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states

    if attn.fused_projections:
        if not attn.is_cross_attention:
            # In self-attention layers, we can fuse the entire QKV projection into a single linear
            query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
        else:
            # In cross-attention layers, we can only fuse the KV projections into a single linear
            query = attn.to_q(hidden_states)
            key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
    else:
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
    return query, key, value


class HeliosAttnProcessor:
    """Default attention processor for `HeliosAttention` (SDPA-based)."""

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "HeliosAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )

    def __call__(
        self,
        attn: "HeliosAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        original_context_length: int = None,
    ) -> torch.Tensor:
        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)

        # QK norm is applied across heads (before the per-head unflatten below).
        query = attn.norm_q(query)
        key = attn.norm_k(key)

        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        if rotary_emb is not None:
            query = apply_rotary_emb_transposed(query, rotary_emb)
            key = apply_rotary_emb_transposed(key, rotary_emb)

        # History amplification: scale the keys of prepended history tokens
        # (everything before `original_context_length`) by a learned factor.
        if not attn.is_cross_attention and attn.is_amplify_history:
            history_seq_len = hidden_states.shape[1] - original_context_length

            if history_seq_len > 0:
                scale_key = attn.get_scale_key()
                if attn.history_scale_mode == "per_head":
                    scale_key = scale_key.view(1, 1, -1, 1)
                key = torch.cat([key[:, :history_seq_len] * scale_key, key[:, history_seq_len:]], dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            # Reference: https://github.com/huggingface/diffusers/pull/12909
            parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class HeliosAttnProcessor2_0:
    """Deprecated alias for `HeliosAttnProcessor`."""

    def __new__(cls, *args, **kwargs):
        deprecation_message = (
            "The HeliosAttnProcessor2_0 class is deprecated and will be removed in a future version. "
            "Please use HeliosAttnProcessor instead. "
        )
        deprecate("HeliosAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
        return HeliosAttnProcessor(*args, **kwargs)


class HeliosAttention(torch.nn.Module, AttentionModuleMixin):
    """Self-/cross-attention module with optional learned history-key amplification."""

    _default_processor_cls = HeliosAttnProcessor
    _available_processors = [HeliosAttnProcessor]

    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        eps: float = 1e-5,
        dropout: float = 0.0,
        added_kv_proj_dim: Optional[int] = None,
        cross_attention_dim_head: Optional[int] = None,
        processor=None,
        is_cross_attention=None,
        is_amplify_history=False,
        history_scale_mode="per_head",  # [scalar, per_head]
    ):
        super().__init__()

        self.inner_dim = dim_head * heads
        self.heads = heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.cross_attention_dim_head = cross_attention_dim_head
        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads

        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_out = torch.nn.ModuleList(
            [
                torch.nn.Linear(self.inner_dim, dim, bias=True),
                torch.nn.Dropout(dropout),
            ]
        )
        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)

        self.add_k_proj = self.add_v_proj = None
        if added_kv_proj_dim is not None:
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)

        if is_cross_attention is not None:
            self.is_cross_attention = is_cross_attention
        else:
            self.is_cross_attention = cross_attention_dim_head is not None

        self.set_processor(processor)

        self.is_amplify_history = is_amplify_history
        if is_amplify_history:
            if history_scale_mode == "scalar":
                self.history_key_scale = nn.Parameter(torch.ones(1))
            elif history_scale_mode == "per_head":
                self.history_key_scale = nn.Parameter(torch.ones(heads))
            else:
                raise ValueError(f"Unknown history_scale_mode: {history_scale_mode}")
            self.history_scale_mode = history_scale_mode
            self.max_scale = 10.0
            # Inference-time cache for the sigmoid-mapped scale (see get_scale_key).
            self.register_buffer("_scale_cache", None)

    def get_scale_key(self):
        """Map `history_key_scale` into (1, max_scale); cache the result when frozen (inference)."""
        if self.history_key_scale.requires_grad:
            scale = 1.0 + torch.sigmoid(self.history_key_scale) * (self.max_scale - 1.0)
        else:
            if self._scale_cache is None:
                self._scale_cache = 1.0 + torch.sigmoid(self.history_key_scale) * (self.max_scale - 1.0)
            scale = self._scale_cache
        return scale

    def fuse_projections(self):
        """Fuse Q/K/V (self-attn) or K/V (cross-attn) linears into a single projection."""
        if getattr(self, "fused_projections", False):
            return

        if not self.is_cross_attention:
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
            out_features, in_features = concatenated_weights.shape
            with torch.device("meta"):
                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
            self.to_qkv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )
        else:
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
            out_features, in_features = concatenated_weights.shape
            with torch.device("meta"):
                self.to_kv = nn.Linear(in_features, out_features, bias=True)
            self.to_kv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )

        if self.added_kv_proj_dim is not None:
            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
            out_features, in_features = concatenated_weights.shape
            with torch.device("meta"):
                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
            self.to_added_kv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )

        self.fused_projections = True

    @torch.no_grad()
    def unfuse_projections(self):
        """Drop any fused projection linears; the original unfused linears remain intact."""
        if not getattr(self, "fused_projections", False):
            return

        if hasattr(self, "to_qkv"):
            delattr(self, "to_qkv")
        if hasattr(self, "to_kv"):
            delattr(self, "to_kv")
        if hasattr(self, "to_added_kv"):
            delattr(self, "to_added_kv")

        self.fused_projections = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        original_context_length: int = None,
        **kwargs,
    ) -> torch.Tensor:
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            rotary_emb,
            original_context_length,
            **kwargs,
        )


class HeliosTimeTextEmbedding(nn.Module):
    """Embeds (possibly per-frame) timesteps and projects text embeddings to the model dim."""

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
    ):
        super().__init__()

        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        is_return_encoder_hidden_states: bool = True,
    ):
        # `timestep` may be (B,) or (B, F) for per-frame timesteps; flatten, embed,
        # then restore the (B, F, ...) layout at the end.
        # NOTE: locals renamed from B/F to avoid shadowing the module-level `F`
        # (torch.nn.functional) alias.
        batch_size = None
        num_frames = None
        if timestep.ndim == 2:
            batch_size, num_frames = timestep.shape
            timestep = timestep.flatten()

        timestep = self.timesteps_proj(timestep)  # (N,) -> (N, time_freq_dim)

        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep)  # (N, dim)
        # Guard: the original unconditionally called type_as(encoder_hidden_states),
        # which crashed when the (optional) encoder_hidden_states was None.
        if encoder_hidden_states is not None:
            temb = temb.type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))  # (N, time_proj_dim)

        if batch_size is not None and num_frames is not None:
            temb = temb.reshape(batch_size, num_frames, -1)
            timestep_proj = timestep_proj.reshape(batch_size, num_frames, -1)

        if encoder_hidden_states is not None and is_return_encoder_hidden_states:
            encoder_hidden_states = self.text_embedder(encoder_hidden_states)  # (B, L, dim)

        return temb, timestep_proj, encoder_hidden_states


class HeliosMultiTermMemoryPatch(nn.Module):
    """Patchify short/mid/long-term history latents at 1x/2x/4x compression and prepend
    the resulting tokens (with matching RoPE frequencies) to the current context."""

    def __init__(self, in_channels, inner_dim):
        super().__init__()
        self.patch_short = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.patch_mid = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
        self.patch_long = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

    def forward(
        self,
        hidden_states,
        rope_freqs,
        rope_fn,
        indices_latents_history_short=None,
        indices_latents_history_mid=None,
        indices_latents_history_long=None,
        latents_history_short=None,
        latents_history_mid=None,
        latents_history_long=None,
    ):
        # Spatial token grid at the 1x patch scale (latent H//2, W//2). BUGFIX: this was
        # previously only set inside the short-history branch, so passing only mid/long
        # history raised a NameError on H1/W1. Each branch can now derive it itself.
        base_grid = None

        # Process clean latents (1x)
        if latents_history_short is not None and indices_latents_history_short is not None:
            latents_history_short = latents_history_short.to(hidden_states)
            latents_history_short = self.patch_short(latents_history_short)
            _, _, _, H1, W1 = latents_history_short.shape
            base_grid = (H1, W1)
            latents_history_short = latents_history_short.flatten(2).transpose(1, 2)

            rope_freqs_history_short = rope_fn(
                frame_indices=indices_latents_history_short,
                height=H1,
                width=W1,
                device=latents_history_short.device,
            )
            rope_freqs_history_short = rope_freqs_history_short.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([latents_history_short, hidden_states], dim=1)
            rope_freqs = torch.cat([rope_freqs_history_short, rope_freqs], dim=1)

        # Process 2x history latents
        if latents_history_mid is not None and indices_latents_history_mid is not None:
            latents_history_mid = latents_history_mid.to(hidden_states)
            if base_grid is None:
                # Fall back to the 1x grid implied by this tensor's own (unpadded)
                # spatial size: the short patch stride is (1, 2, 2).
                base_grid = (latents_history_mid.shape[-2] // 2, latents_history_mid.shape[-1] // 2)
            latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4))
            latents_history_mid = self.patch_mid(latents_history_mid)
            latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2)

            # RoPE is computed at the 1x grid, then pooled 2x to match the mid tokens.
            rope_freqs_history_mid = rope_fn(
                frame_indices=indices_latents_history_mid,
                height=base_grid[0],
                width=base_grid[1],
                device=latents_history_mid.device,
            )
            rope_freqs_history_mid = pad_for_3d_conv(rope_freqs_history_mid, (2, 2, 2))
            rope_freqs_history_mid = center_down_sample_3d(rope_freqs_history_mid, (2, 2, 2))
            rope_freqs_history_mid = rope_freqs_history_mid.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1)
            rope_freqs = torch.cat([rope_freqs_history_mid, rope_freqs], dim=1)

        # Process 4x history latents
        if latents_history_long is not None and indices_latents_history_long is not None:
            latents_history_long = latents_history_long.to(hidden_states)
            if base_grid is None:
                base_grid = (latents_history_long.shape[-2] // 2, latents_history_long.shape[-1] // 2)
            latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8))
            latents_history_long = self.patch_long(latents_history_long)
            latents_history_long = latents_history_long.flatten(2).transpose(1, 2)

            # RoPE is computed at the 1x grid, then pooled 4x to match the long tokens.
            rope_freqs_history_long = rope_fn(
                frame_indices=indices_latents_history_long,
                height=base_grid[0],
                width=base_grid[1],
                device=latents_history_long.device,
            )
            rope_freqs_history_long = pad_for_3d_conv(rope_freqs_history_long, (4, 4, 4))
            rope_freqs_history_long = center_down_sample_3d(rope_freqs_history_long, (4, 4, 4))
            rope_freqs_history_long = rope_freqs_history_long.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([latents_history_long, hidden_states], dim=1)
            rope_freqs = torch.cat([rope_freqs_history_long, rope_freqs], dim=1)

        return hidden_states, rope_freqs


class HeliosRotaryPosEmbed(nn.Module):
    """3D (time/height/width) rotary position embeddings with arbitrary frame indices."""

    def __init__(self, rope_dim, theta):
        super().__init__()
        self.DT, self.DY, self.DX = rope_dim
        self.theta = theta
        self.register_buffer("freqs_base_t", self._get_freqs_base(self.DT), persistent=False)
        self.register_buffer("freqs_base_y", self._get_freqs_base(self.DY), persistent=False)
        self.register_buffer("freqs_base_x", self._get_freqs_base(self.DX), persistent=False)

    def _get_freqs_base(self, dim):
        return 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim))

    @torch.no_grad()
    def get_frequency_batched(self, freqs_base, pos):
        freqs = torch.einsum("d,bthw->dbthw", freqs_base, pos)
        freqs = freqs.repeat_interleave(2, dim=0)
        return freqs.cos(), freqs.sin()

    # NOTE(review): lru_cache on a bound method keys on `self` and keeps the instance
    # alive for the cache's lifetime (ruff B019). Acceptable here only if this module
    # lives for the whole process — confirm before reusing this pattern elsewhere.
    @torch.no_grad()
    @lru_cache(maxsize=32)
    def _get_spatial_meshgrid(self, height, width, device_str):
        device = torch.device(device_str)
        gy = torch.arange(height, device=device, dtype=torch.float32)
        gx = torch.arange(width, device=device, dtype=torch.float32)
        GY, GX = torch.meshgrid(gy, gx, indexing="ij")
        return GY, GX

    @torch.no_grad()
    def forward(self, frame_indices, height, width, device):
        """Return RoPE frequencies of shape (B, D, T, H, W) with [cos | cos | cos | sin | sin | sin]
        blocks for the t/y/x axes concatenated on the channel dim."""
        B = frame_indices.shape[0]
        T = frame_indices.shape[1]

        frame_indices = frame_indices.to(device=device, dtype=torch.float32)
        GY, GX = self._get_spatial_meshgrid(height, width, str(device))

        GT = frame_indices[:, :, None, None].expand(B, T, height, width)
        GY_batch = GY[None, None, :, :].expand(B, T, -1, -1)
        GX_batch = GX[None, None, :, :].expand(B, T, -1, -1)

        FCT, FST = self.get_frequency_batched(self.freqs_base_t, GT)
        FCY, FSY = self.get_frequency_batched(self.freqs_base_y, GY_batch)
        FCX, FSX = self.get_frequency_batched(self.freqs_base_x, GX_batch)

        result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)

        return result.permute(1, 0, 2, 3, 4)
Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        # AdaLN-style modulation table: 6 vectors (shift/scale/gate for self-attn and ffn).
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

        # 4. Guidance cross-attention
        self.guidance_cross_attn = guidance_cross_attn

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
        original_context_length: Optional[int] = None,
    ) -> torch.Tensor:
        """Run one block: modulated self-attention, cross-attention, feed-forward.

        `temb` carries the six AdaLN modulation vectors. When `temb.ndim == 4` the
        modulation is per-token (chunked on dim=2, then squeezed); otherwise it is
        per-sequence (chunked on dim=1). `original_context_length` is the number of
        non-history tokens at the *tail* of the sequence; with `guidance_cross_attn`
        set, cross-attention is applied to those tokens only and the history prefix
        is passed through untouched.
        """
        if temb.ndim == 4:
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table.unsqueeze(0) + temb.float()
            ).chunk(6, dim=2)
            # batch_size, seq_len, 1, inner_dim
            shift_msa = shift_msa.squeeze(2)
            scale_msa = scale_msa.squeeze(2)
            gate_msa = gate_msa.squeeze(2)
            c_shift_msa = c_shift_msa.squeeze(2)
            c_scale_msa = c_scale_msa.squeeze(2)
            c_gate_msa = c_gate_msa.squeeze(2)
        else:
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table + temb.float()
            ).chunk(6, dim=1)

        # 1. Self-attention (normalization and modulation run in fp32, then cast back)
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output = self.attn1(
            norm_hidden_states,
            None,
            None,
            rotary_emb,
            original_context_length,
        )
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        if self.guidance_cross_attn:
            # Split off the history prefix so only the current-context tokens attend to text.
            history_seq_len = hidden_states.shape[1] - original_context_length

            history_hidden_states, hidden_states = (
                hidden_states[:, :history_seq_len],
                hidden_states[:, history_seq_len:],
            )
            norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states,
                None,
                None,
                original_context_length,
            )
            hidden_states = hidden_states + attn_output
            hidden_states = torch.cat([history_hidden_states, hidden_states], dim=1)
        else:
            norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states,
                None,
                None,
                original_context_length,
            )
            hidden_states = hidden_states + attn_output

        # 3. Feed-forward
        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)

        return hidden_states


class HeliosTransformer3DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
    r"""
    A Transformer model for video-like data used in the Helios model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads in each transformer block.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `40`):
            The number of layers of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Query/key normalization mode passed to the attention blocks.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        guidance_cross_attn (`bool`, defaults to `False`):
            If set, cross-attention in each block is applied only to the non-history tokens of the sequence.
        zero_history_timestep (`bool`, defaults to `False`):
            If set, history tokens are conditioned with a zero timestep embedding in `forward`.
        has_multi_term_memory_patch (`bool`, defaults to `False`):
            Whether to instantiate the `HeliosMultiTermMemoryPatch` that folds short/mid/long history latents
            into the input sequence.
        is_amplify_history (`bool`, defaults to `False`):
            Forwarded to the self-attention blocks.
        history_scale_mode (`str`, defaults to `"per_head"`):
            One of `"scalar"` or `"per_head"`; forwarded to the self-attention blocks.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["HeliosTransformerBlock"]
    # Kept in fp32 even under layerwise casting (modulation/normalization precision).
    _keep_in_fp32_modules = [
        "time_embedder",
        "scale_shift_table",
        "norm1",
        "norm2",
        "norm3",
        "history_key_scale",
    ]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["HeliosTransformerBlock"]
    # Context-parallel sharding plan: split the token dim on entry, gather after the last block.
    # NOTE(review): "blocks.39" hard-codes num_layers=40 — confirm intended for other depths.
    _cp_plan = {
        "blocks.0": {
            "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "blocks.*": {
            "temb": ContextParallelInput(split_dim=1, expected_dims=4, split_output=False),
            "rotary_emb": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "blocks.39": ContextParallelOutput(gather_dim=1, expected_dims=3),
    }

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        added_kv_proj_dim: Optional[int] = None,
        guidance_cross_attn: bool = False,
        zero_history_timestep: bool = False,
        has_multi_term_memory_patch: bool = False,
        is_amplify_history: bool = False,
        history_scale_mode: str = "per_head",  # [scalar, per_head]
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = HeliosRotaryPosEmbed(rope_dim=(44, 42, 42), theta=10000.0)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Initial Multi Term Memory Patch
        self.zero_history_timestep = zero_history_timestep
        self.inner_dim = inner_dim
        if has_multi_term_memory_patch:
            self.multi_term_memory_patch = HeliosMultiTermMemoryPatch(in_channels, self.inner_dim)

        # 3. Condition embeddings
        self.condition_embedder = HeliosTimeTextEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
        )

        # 4. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                HeliosTransformerBlock(
                    inner_dim,
                    ffn_dim,
                    num_attention_heads,
                    qk_norm,
                    cross_attn_norm,
                    eps,
                    added_kv_proj_dim,
                    guidance_cross_attn=guidance_cross_attn,
                    is_amplify_history=is_amplify_history,
                    history_scale_mode=history_scale_mode,
                )
                for _ in range(num_layers)
            ]
        )

        # 5. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False

    def process_input_hidden_states(
        self,
        latents,
        indices_hidden_states=None,
        indices_latents_history_short=None,
        indices_latents_history_mid=None,
        indices_latents_history_long=None,
        latents_history_short=None,
        latents_history_mid=None,
        latents_history_long=None,
    ):
        """Patchify latents, build RoPE tables, and optionally prepend history tokens.

        Returns a tuple of (token sequence, rope frequencies, per-segment height /
        width / temporal lists, and per-segment sequence lengths). The lists describe
        only the current (non-history) segment; history tokens are prepended by
        `multi_term_memory_patch` when short-history latents are supplied.
        """
        height_list = []
        width_list = []
        temporal_list = []
        seq_list = []

        hidden_states = self.patch_embedding(latents)
        B, C, T, H, W = hidden_states.shape

        # Default temporal indices: 0..T-1 for every batch element.
        if indices_hidden_states is None:
            indices_hidden_states = torch.arange(0, T).unsqueeze(0).expand(B, -1)

        hidden_states = hidden_states.flatten(2).transpose(
            1, 2
        )  # torch.Size([1, 3072, 9, 44, 34]) -> torch.Size([1, 13464, 3072])

        rope_freqs = self.rope(
            frame_indices=indices_hidden_states,
            height=H,
            width=W,
            device=hidden_states.device,
        )  # torch.Size([1, 9]) -> torch.Size([1, 256, 9, 44, 34])
        rope_freqs = rope_freqs.flatten(2).transpose(1, 2)  # torch.Size([1, 13464, 256])

        height_list.append(H)
        width_list.append(W)
        temporal_list.append(T)
        seq_list.append(hidden_states.shape[1])

        # Fold short/mid/long history latents into the sequence (prepended as a prefix).
        if latents_history_short is not None:
            hidden_states, rope_freqs = self.multi_term_memory_patch(
                hidden_states=hidden_states,
                rope_freqs=rope_freqs,
                rope_fn=self.rope,
                latents_history_short=latents_history_short,
                indices_latents_history_short=indices_latents_history_short,
                latents_history_mid=latents_history_mid,
                indices_latents_history_mid=indices_latents_history_mid,
                latents_history_long=latents_history_long,
                indices_latents_history_long=indices_latents_history_long,
            )

        return (
            hidden_states,
            rope_freqs,
            height_list,
            width_list,
            temporal_list,
            seq_list,
        )
    @apply_lora_scale("attention_kwargs")
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        # ------------ Stage 1 ------------
        indices_hidden_states: Optional[torch.Tensor] = None,
        indices_latents_history_short: Optional[torch.Tensor] = None,
        indices_latents_history_mid: Optional[torch.Tensor] = None,
        indices_latents_history_long: Optional[torch.Tensor] = None,
        latents_history_short: Optional[torch.Tensor] = None,
        latents_history_mid: Optional[torch.Tensor] = None,
        latents_history_long: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: dict[str, Any] | None = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Denoise a batch of video latents, optionally conditioned on history latents.

        The history arguments must be supplied all together or not at all. History
        tokens are prepended to the sequence by `process_input_hidden_states`; only the
        trailing `original_context_length` tokens are normalized, projected, and
        unpatchified into the output. Returns a `Transformer2DModelOutput` (or a
        1-tuple when `return_dict=False`).
        """
        assert (
            len(
                {
                    x is None
                    for x in [
                        indices_hidden_states,
                        indices_latents_history_short,
                        indices_latents_history_mid,
                        indices_latents_history_long,
                        latents_history_short,
                        latents_history_mid,
                        latents_history_long,
                    ]
                }
            )
            == 1
        ), "All history latents and indices must either all exist or all be None"

        # Promote unbatched (T,) index tensors to (1, T).
        if indices_hidden_states is not None and indices_hidden_states.ndim == 1:
            indices_hidden_states = indices_hidden_states.unsqueeze(0)
        if indices_latents_history_short is not None and indices_latents_history_short.ndim == 1:
            indices_latents_history_short = indices_latents_history_short.unsqueeze(0)
        if indices_latents_history_mid is not None and indices_latents_history_mid.ndim == 1:
            indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0)
        if indices_latents_history_long is not None and indices_latents_history_long.ndim == 1:
            indices_latents_history_long = indices_latents_history_long.unsqueeze(0)

        batch_size = hidden_states.shape[0]
        p_t, p_h, p_w = self.config.patch_size

        (
            hidden_states,
            rotary_emb,
            post_patch_height_list,
            post_patch_width_list,
            post_patch_num_frames_list,
            original_context_length_list,
        ) = self.process_input_hidden_states(
            latents=hidden_states,
            indices_hidden_states=indices_hidden_states,
            indices_latents_history_short=indices_latents_history_short,
            indices_latents_history_mid=indices_latents_history_mid,
            indices_latents_history_long=indices_latents_history_long,
            latents_history_short=latents_history_short,
            latents_history_mid=latents_history_mid,
            latents_history_long=latents_history_long,
        )  # hidden: [high, mid, low] -> [low, mid, high]
        post_patch_num_frames = sum(post_patch_num_frames_list)
        post_patch_height = sum(post_patch_height_list)
        post_patch_width = sum(post_patch_width_list)
        original_context_length = sum(original_context_length_list)
        # Number of prepended history tokens (0 when no history was folded in).
        history_context_length = hidden_states.shape[1] - original_context_length

        # Optionally condition history tokens with a timestep-0 embedding.
        if indices_hidden_states is not None and self.zero_history_timestep:
            timestep_t0 = torch.zeros((1), dtype=timestep.dtype, device=timestep.device)
            temb_t0, timestep_proj_t0, _ = self.condition_embedder(
                timestep_t0, encoder_hidden_states, is_return_encoder_hidden_states=False
            )
            temb_t0 = temb_t0.unsqueeze(1).expand(batch_size, history_context_length, -1)
            timestep_proj_t0 = (
                timestep_proj_t0.unflatten(-1, (6, -1))
                .view(1, 6, 1, -1)
                .expand(batch_size, -1, history_context_length, -1)
            )

        temb, timestep_proj, encoder_hidden_states = self.condition_embedder(timestep, encoder_hidden_states)
        timestep_proj = timestep_proj.unflatten(-1, (6, -1))

        # Without zero_history_timestep the same timestep embedding covers history tokens too.
        if indices_hidden_states is not None and not self.zero_history_timestep:
            main_repeat_size = hidden_states.shape[1]
        else:
            main_repeat_size = original_context_length
        temb = temb.view(batch_size, 1, -1).expand(batch_size, main_repeat_size, -1)
        timestep_proj = timestep_proj.view(batch_size, 6, 1, -1).expand(batch_size, 6, main_repeat_size, -1)

        if indices_hidden_states is not None and self.zero_history_timestep:
            temb = torch.cat([temb_t0, temb], dim=1)
            timestep_proj = torch.cat([timestep_proj_t0, timestep_proj], dim=2)

        # Blocks expect per-token modulation as (batch, seq, 6, dim).
        if timestep_proj.ndim == 4:
            timestep_proj = timestep_proj.permute(0, 2, 1, 3)

        # 4. Transformer blocks
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()
        rotary_emb = rotary_emb.contiguous()
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    original_context_length,
                )
        else:
            for block in self.blocks:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    original_context_length,
                )

        # 5. Output norm, projection & unpatchify
        if temb.ndim == 3:
            # Per-token conditioning: keep only the non-history tail.
            temb = temb[:, -original_context_length:, :]
            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)
            scale = scale.squeeze(2)
        else:
            # batch_size, inner_dim
            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        # Drop the history prefix before projecting back to pixel-patch space.
        hidden_states = hidden_states[:, -original_context_length:, :]
        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify: (B, T, H, W, p_t, p_h, p_w, C) -> (B, C, T*p_t, H*p_h, W*p_w)
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 638598051d64..0002493e97e6 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -237,6 +237,7 @@ "EasyAnimateInpaintPipeline", "EasyAnimateControlPipeline", ] + _import_structure["helios"] = ["HeliosPipeline"] _import_structure["hidream_image"] = ["HiDreamImagePipeline"] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] _import_structure["hunyuan_video"] = [ @@ -667,6 +668,7 @@ ) from .flux2 import Flux2KleinPipeline, Flux2Pipeline from .glm_image import GlmImagePipeline + from .helios import HeliosPipeline from .hidream_image import HiDreamImagePipeline from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline from .hunyuan_video import ( diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 7247ca0d161c..52af1b1c6b46 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -54,6 +54,7 @@ ) from .flux2 import Flux2KleinPipeline, Flux2Pipeline from .glm_image import GlmImagePipeline +from .helios import HeliosPipeline from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -174,6 +175,7 @@ ("cogview3",
CogView3PlusPipeline), ("cogview4", CogView4Pipeline), ("glm_image", GlmImagePipeline), + ("helios", HeliosPipeline), ("cogview4-control", CogView4ControlPipeline), ("qwenimage", QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), diff --git a/src/diffusers/pipelines/helios/__init__.py b/src/diffusers/pipelines/helios/__init__.py new file mode 100644 index 000000000000..a540d1462389 --- /dev/null +++ b/src/diffusers/pipelines/helios/__init__.py @@ -0,0 +1,46 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_helios"] = ["HeliosPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_helios import HeliosPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py new file mode 100644 index 000000000000..c59c809f4ef8 --- /dev/null +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -0,0 +1,1519 @@ +# Copyright 2025 The Wan Team and The 
HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import math +from enum import Enum +from itertools import accumulate +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +import torch.nn.functional as F +from einops import rearrange +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import HeliosLoraLoaderMixin +from ...models import AutoencoderKLWan, HeliosTransformer3DModel +from ...schedulers import HeliosUniPCScheduler, UniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HeliosPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers.utils import export_to_video + >>> from diffusers import AutoencoderKLWan, HeliosPipeline + + >>> # Available models: BestWishYsh/Helios-Base, BestWishYsh/Helios-Mid, 
BestWishYsh/Helios-Distilled + >>> model_id = "BestWishYsh/Helios-Base" + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = HeliosPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=384, + ... width=640, + ... num_frames=132, + ... guidance_scale=5.0, + ... 
).frames[0] + >>> export_to_video(output, "output.mp4", fps=24) + ``` +""" + + +@torch.amp.autocast("cuda", dtype=torch.float32) +def optimized_scale(positive_flat, negative_flat): + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, + exp_max=7.0, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def apply_schedule_shift( + sigmas, + noise, + sigmas_two=None, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, + exp_max: float = 7.0, + is_exponential: bool = False, + mu: float = None, + return_mu: bool = False, +): + if mu is None: + # Resolution-dependent shifting of timestep schedules as per section 5.3.2 of SD3 paper + image_seq_len = (noise.shape[-1] * noise.shape[-2] * noise.shape[-3]) // 4 # patch size 1,2,2 + mu = calculate_shift( + image_seq_len, + base_seq_len if base_seq_len is not None else 256, + max_seq_len if max_seq_len is not None else 4096, + base_shift if base_shift is not None else 0.5, + max_shift if max_shift is not None else 1.15, + exp_max if exp_max is not None else 7.0, + ) + if is_exponential: + mu = min(mu, math.log(exp_max)) + mu = math.exp(mu) + + if sigmas_two 
is not None: + sigmas = (sigmas * mu) / (1 + (mu - 1) * sigmas) + sigmas_two = (sigmas_two * mu) / (1 + (mu - 1) * sigmas_two) + if return_mu: + return sigmas, sigmas_two, mu + else: + return sigmas, sigmas_two + else: + sigmas = (sigmas * mu) / (1 + (mu - 1) * sigmas) + if return_mu: + return sigmas, mu + else: + return sigmas + + +def add_noise(original_samples, noise, timestep, sigmas, timesteps): + sigmas = sigmas.to(noise.device) + timesteps = timesteps.to(noise.device) + timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1) + sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1) + sample = (1 - sigma) * original_samples + sigma * noise + return sample.type_as(noise) + + +def convert_flow_pred_to_x0(flow_pred, xt, timestep, sigmas, timesteps): + # use higher precision for calculations + original_dtype = flow_pred.dtype + device = flow_pred.device + flow_pred, xt, sigmas, timesteps = (x.double().to(device) for x in (flow_pred, xt, sigmas, timesteps)) + + timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1) + sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1) + x0_pred = xt - sigma_t * flow_pred + return x0_pred.to(original_dtype) + + +class VAEDecodeType(str, Enum): + DEFAULT = "default" + DEFAULT_BATCH = "default_batch" + + +class HeliosPipeline(DiffusionPipeline, HeliosLoraLoaderMixin): + r""" + Pipeline for text-to-video / image-to-video / video-to-video generation using Helios. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. 
+ text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`HeliosTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: Union[UniPCMultistepScheduler, HeliosUniPCScheduler], + transformer: HeliosTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", 
                # (continuation of `_get_t5_prompt_embeds` -- tokenizer call arguments)
                max_length=max_sequence_length,
                truncation=True,
                add_special_tokens=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
            # Number of real (non-padding) tokens per prompt.
            seq_lens = mask.gt(0).sum(dim=1).long()

            prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
            prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
            # Trim each sequence to its true length, then zero-pad back up to
            # `max_sequence_length`, so padding positions hold exact zeros rather
            # than whatever the encoder produced there.
            prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
            prompt_embeds = torch.stack(
                [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
            )

            # duplicate text embeddings for each generation per prompt, using mps friendly method
            _, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

            # NOTE(review): the attention mask is returned without being repeated for
            # `num_videos_per_prompt`, unlike `prompt_embeds` above -- confirm callers only
            # rely on the mask when num_videos_per_prompt == 1.
            return prompt_embeds, text_inputs.attention_mask.bool()

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype

        Returns:
            A 4-tuple `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds,
            negative_prompt_attention_mask)`; the negative entries are `None` when classifier-free
            guidance is disabled or negative embeds were supplied.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # NOTE(review): if the caller passes `prompt_embeds` directly, `prompt_attention_mask`
        # is never assigned and the `return` below raises NameError -- confirm whether
        # pre-computed embeds are meant to be supported here.
        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        negative_prompt_attention_mask = None
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

    def check_inputs(
        self,
        prompt,
        negative_prompt,
        height,
        width,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """Validate the user-facing call arguments, raising `ValueError` on any invalid combination."""
        # Spatial dims must be divisible by 16 (VAE/patch downsampling constraint per this check).
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one of (prompt, prompt_embeds) must be given; same for the negative pair.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif negative_prompt is not None and (
            not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
        ):
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Return initial noise latents of shape (B, C, T_latent, H/vae_s, W/vae_s), or the
        user-supplied `latents` moved to `device`/`dtype` if given.
        """
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        # Temporal compression: the VAE maps (num_frames - 1) extra frames down by
        # `vae_scale_factor_temporal`, plus one latent frame for the first pixel frame.
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_channels_latents,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    def prepare_image_latents(
        self,
        image: torch.Tensor,
        latents_mean: torch.Tensor,
        latents_std: torch.Tensor,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        fake_latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode a conditioning image into VAE latents for I2V.

        Note: `latents_std` as passed by `__call__` is the *reciprocal* of the VAE std
        (it computes `1.0 / latents_std` there), so multiplying here normalizes.
        Despite the `-> torch.Tensor` annotation this returns a 2-tuple
        `(image_latents, fake_latents)`.
        """
        device = device or self._execution_device
        if latents is None:
            # Add a singleton time axis: (B, C, H, W) -> (B, C, 1, H, W).
            image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype)
            latents = self.vae.encode(image).latent_dist.sample(generator=generator)
            latents = (latents - latents_mean) * latents_std
        if fake_latents is None:
            # Encode the image repeated as a 33-frame still video and keep only the last
            # latent frame; presumably this seeds the history buffer -- TODO confirm intent.
            fake_video = image.repeat(1, 1, 33, 1, 1).to(device=device, dtype=self.vae.dtype)
            fake_latents_full = self.vae.encode(fake_video).latent_dist.sample(generator=generator)
            fake_latents_full = (fake_latents_full - latents_mean) * latents_std
            fake_latents = fake_latents_full[:, :, -1:, :, :]
        return latents.to(device=device, dtype=dtype), fake_latents.to(device=device, dtype=dtype)

    def prepare_video_latents(
        self,
        video: torch.Tensor,
        latents_mean: torch.Tensor,
        latents_std: torch.Tensor,
        latent_window_size: int,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode a conditioning video into VAE latents for V2V, chunked so each chunk
        yields exactly `latent_window_size` latent frames.

        Despite the `-> torch.Tensor` annotation this returns a 2-tuple
        `(first_frame_latent, latents)`.

        NOTE(review): if the caller supplies `latents`, `first_frame_latent` is never
        assigned and the final `return` raises NameError -- confirm whether the
        pre-computed-latents path is actually used.
        """
        device = device or self._execution_device
        video = video.to(device=device, dtype=self.vae.dtype)
        if latents is None:
            num_frames = video.shape[2]
            # Minimum pixel frames per chunk for one latent window (4x temporal compression).
            min_frames = (latent_window_size - 1) * 4 + 1
            num_chunks = num_frames // min_frames
            if num_chunks == 0:
                raise ValueError(
                    f"Video must have at least {min_frames} frames "
                    f"(got {num_frames} frames). "
                    f"Required: (latent_window_size - 1) * 4 + 1 = ({latent_window_size} - 1) * 4 + 1 = {min_frames}"
                )
            # Keep the trailing `num_chunks * min_frames` frames; earlier frames are dropped.
            total_valid_frames = num_chunks * min_frames
            start_frame = num_frames - total_valid_frames

            first_frame = video[:, :, 0:1, :, :]
            first_frame_latent = self.vae.encode(first_frame).latent_dist.sample(generator=generator)
            first_frame_latent = (first_frame_latent - latents_mean) * latents_std

            # Encode chunks last-to-first but `insert(0, ...)` so the final list is back
            # in chronological order.
            latents_chunks = []
            for i in range(num_chunks - 1, -1, -1):
                chunk_start = start_frame + i * min_frames
                chunk_end = chunk_start + min_frames
                video_chunk = video[:, :, chunk_start:chunk_end, :, :]
                chunk_latents = self.vae.encode(video_chunk).latent_dist.sample(generator=generator)
                chunk_latents = (chunk_latents - latents_mean) * latents_std
                latents_chunks.insert(0, chunk_latents)
            latents = torch.cat(latents_chunks, dim=2)
        return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype)

    def interpolate_prompt_embeds(
        self,
        prompt_embeds_1: torch.Tensor,
        prompt_embeds_2: torch.Tensor,
        interpolation_steps: int = 4,
    ):
        """
        Linearly interpolate between two prompt embeddings.

        Returns a list of `interpolation_steps` tensors going from `prompt_embeds_1`
        (weight 0) to `prompt_embeds_2` (weight 1) inclusive.
        """
        x = torch.lerp(
            prompt_embeds_1,
            prompt_embeds_2,
            torch.linspace(0, 1, steps=interpolation_steps).unsqueeze(1).unsqueeze(2).to(prompt_embeds_1),
        )
        interpolated_prompt_embeds = list(x.chunk(interpolation_steps, dim=0))
        return interpolated_prompt_embeds

    def sample_block_noise(self, batch_size, channel, num_frames, height, width):
        """
        Sample noise that is correlated within each 2x2 spatial block.

        Each block of 4 pixels is drawn from a zero-mean multivariate normal with
        covariance `(1 + gamma) * I - gamma * ones(4, 4)` (gamma from the scheduler
        config), then reshaped back to (B, C, T, H, W).
        Assumes `height` and `width` are even -- TODO confirm callers guarantee this.
        """
        gamma = self.scheduler.config.gamma
        cov = torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma
        dist = torch.distributions.MultivariateNormal(torch.zeros(4, device=cov.device), covariance_matrix=cov)
        block_number = batch_size * channel * num_frames * (height // 2) * (width // 2)

        noise = dist.sample((block_number,))  # [block number, 4]
        # Unfold the 4-vector into a 2x2 spatial block and interleave it back into H and W.
        noise = noise.view(batch_size, channel, num_frames, height // 2, width // 2, 2, 2)
        noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width)
        return noise

    def stage1_sample(
        self,
        latents: torch.Tensor = None,
        prompt_embeds: torch.Tensor = None,
        negative_prompt_embeds: torch.Tensor = None,
        timesteps: torch.Tensor = None,
        guidance_scale: Optional[float] = 5.0,
        indices_hidden_states: torch.Tensor = None,
        indices_latents_history_short: torch.Tensor = None,
        indices_latents_history_mid: torch.Tensor = None,
        indices_latents_history_long: torch.Tensor = None,
        latents_history_short: torch.Tensor = None,
        latents_history_mid: torch.Tensor = None,
        latents_history_long: torch.Tensor = None,
        attention_kwargs: Optional[dict] = None,
        device: Optional[torch.device] = None,
        transformer_dtype: torch.dtype = None,
        scheduler_type: str = "unipc",
        use_dynamic_shifting: bool = False,
        generator: Optional[torch.Generator] = None,
        # ------------ CFG Zero ------------
        use_cfg_zero_star: Optional[bool] = False,
        use_zero_init: Optional[bool] = True,
        zero_steps: Optional[int] = 1,
        # -------------- DMD --------------
        use_dmd: bool = False,
        dmd_sigmas: torch.Tensor = None,
        dmd_timesteps: torch.Tensor = None,
        is_amplify_first_chunk: bool = False,
        # ------------ Callback ------------
        callback_on_step_end: Optional[callable] = None,
        callback_on_step_end_tensor_inputs: list = None,
        progress_bar=None,
    ):
        """
        Stage-1 denoising loop for one latent window.

        Iterates over `timesteps`, running the transformer on `latents` conditioned on
        text embeds and short/mid/long latent history buffers. Supports plain CFG,
        CFG-Zero* rescaling, and a DMD path that converts the flow prediction to x0 and
        re-noises it toward the next timestep instead of calling the scheduler.
        Returns the denoised latents.

        NOTE(review): `scheduler_type`, `use_dynamic_shifting` and `is_amplify_first_chunk`
        appear unused in this body -- presumably kept for signature parity with
        `stage2_sample`; confirm.
        """
        batch_size = latents.shape[0]

        for i, t in enumerate(timesteps):
            # `interrupt` skips the remaining model work but keeps iterating (no break),
            # matching the usual diffusers interrupt convention.
            if self.interrupt:
                continue

            self._current_timestep = t
            timestep = t.expand(latents.shape[0])

            latent_model_input = latents.to(transformer_dtype)
            with self.transformer.cache_context("cond"):
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    indices_hidden_states=indices_hidden_states,
                    indices_latents_history_short=indices_latents_history_short,
                    indices_latents_history_mid=indices_latents_history_mid,
                    indices_latents_history_long=indices_latents_history_long,
                    latents_history_short=latents_history_short.to(transformer_dtype),
                    latents_history_mid=latents_history_mid.to(transformer_dtype),
                    latents_history_long=latents_history_long.to(transformer_dtype),
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]

            # DMD is a distilled (guidance-free) path, so the uncond pass is skipped.
            if self.do_classifier_free_guidance and not use_dmd:
                with self.transformer.cache_context("uncond"):
                    noise_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        indices_hidden_states=indices_hidden_states,
                        indices_latents_history_short=indices_latents_history_short,
                        indices_latents_history_mid=indices_latents_history_mid,
                        indices_latents_history_long=indices_latents_history_long,
                        latents_history_short=latents_history_short.to(transformer_dtype),
                        latents_history_mid=latents_history_mid.to(transformer_dtype),
                        latents_history_long=latents_history_long.to(transformer_dtype),
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]

                if use_cfg_zero_star:
                    # CFG-Zero*: rescale the uncond branch by a per-sample projection
                    # scale `alpha`, and zero out the very first step(s) when enabled.
                    noise_pred_text = noise_pred
                    positive_flat = noise_pred_text.view(batch_size, -1)
                    negative_flat = noise_uncond.view(batch_size, -1)

                    alpha = optimized_scale(positive_flat, negative_flat)
                    alpha = alpha.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1)))
                    alpha = alpha.to(noise_pred_text.dtype)

                    if (i <= zero_steps) and use_zero_init:
                        noise_pred = noise_pred_text * 0.0
                    else:
                        noise_pred = noise_uncond * alpha + guidance_scale * (noise_pred_text - noise_uncond * alpha)
                else:
                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

            if use_dmd:
                # DMD: predict x0 from the flow prediction, then re-noise it to the next
                # timestep manually (no scheduler step).
                pred_image_or_video = convert_flow_pred_to_x0(
                    flow_pred=noise_pred,
                    xt=latent_model_input,
                    timestep=t * torch.ones(batch_size, dtype=torch.long, device=noise_pred.device),
                    sigmas=dmd_sigmas,
                    timesteps=dmd_timesteps,
                )
                if i < len(timesteps) - 1:
                    latents = add_noise(
                        pred_image_or_video,
                        randn_tensor(pred_image_or_video.shape, generator=generator, device=device),
                        timesteps[i + 1] * torch.ones(batch_size, dtype=torch.long, device=noise_pred.device),
                        sigmas=dmd_sigmas,
                        timesteps=dmd_timesteps,
                    )
                else:
                    latents = pred_image_or_video
            else:
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                # NOTE(review): `locals()[k]` only resolves names that exist in this
                # frame; any entry of `_callback_tensor_inputs` not local here raises
                # KeyError -- fragile, confirm allowed keys.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

        return latents

    def stage2_sample(
        self,
        latents: torch.Tensor = None,
        stage2_num_stages: int = None,
        stage2_num_inference_steps_list: List[int] = None,
        prompt_embeds: torch.Tensor = None,
        negative_prompt_embeds: torch.Tensor = None,
        guidance_scale: Optional[float] = 5.0,
        indices_hidden_states: torch.Tensor = None,
        indices_latents_history_short: torch.Tensor = None,
        indices_latents_history_mid: torch.Tensor = None,
        indices_latents_history_long: torch.Tensor = None,
        latents_history_short: torch.Tensor = None,
        latents_history_mid: torch.Tensor = None,
        latents_history_long: torch.Tensor = None,
        attention_kwargs: Optional[dict] = None,
        device: Optional[torch.device] = None,
        transformer_dtype: torch.dtype = None,
        scheduler_type: str = "unipc",  # unipc, euler
        use_dynamic_shifting: bool = False,
        # ------------ CFG Zero ------------
        use_cfg_zero_star: Optional[bool] = False,
        use_zero_init: Optional[bool] = True,
        zero_steps: Optional[int] = 1,
        # -------------- DMD --------------
        use_dmd: bool = False,
        is_amplify_first_chunk: bool = False,
        # ------------ Callback ------------
        callback_on_step_end: Optional[callable] = None,
        callback_on_step_end_tensor_inputs: list = None,
        progress_bar=None,
    ):
        """
        Stage-2 multi-resolution refinement loop.

        Downscales `latents` by 2x per extra stage, then for each stage: re-times the
        scheduler, upsamples (from stage 1 on), injects 2x2-block-correlated noise to
        suppress block artifacts, and denoises over that stage's timesteps with the
        same CFG / CFG-Zero* / DMD options as `stage1_sample`. Returns refined latents.
        """
        # NOTE(review): `num_frmaes` is a typo for `num_frames`, used consistently below.
        num_frmaes, height, width = (
            latents.shape[-3],
            latents.shape[-2],
            latents.shape[-1],
        )
        # Fold time into the batch so 2D interpolation can be applied per frame.
        latents = rearrange(latents, "b c t h w -> (b t) c h w")
        for _ in range(stage2_num_stages - 1):
            height //= 2
            width //= 2
            # The `* 2` rescales latent magnitude after downsampling -- TODO confirm rationale.
            latents = (
                F.interpolate(
                    latents,
                    size=(height, width),
                    mode="bilinear",
                )
                * 2
            )
        latents = rearrange(latents, "(b t) c h w -> b c t h w", t=num_frmaes)

        batch_size = latents.shape[0]
        if use_dmd:
            # One "start point" per stage, used as the re-noising source below.
            start_point_list = [latents]

        # Global step counter across all stages (unlike `idx`, which restarts per stage);
        # this is the index handed to the step-end callback.
        i = 0
        for i_s in range(stage2_num_stages):
            if use_dmd:
                # DMD requests one extra step then trims the last timestep and drops the
                # second-to-last sigma, keeping the terminal sigma.
                if is_amplify_first_chunk:
                    self.scheduler.set_timesteps(stage2_num_inference_steps_list[i_s] * 2 + 1, i_s, device=device)
                else:
                    self.scheduler.set_timesteps(stage2_num_inference_steps_list[i_s] + 1, i_s, device=device)
                self.scheduler.timesteps = self.scheduler.timesteps[:-1]
                self.scheduler.sigmas = torch.cat([self.scheduler.sigmas[:-2], self.scheduler.sigmas[-1:]])
            else:
                self.scheduler.set_timesteps(stage2_num_inference_steps_list[i_s], i_s, device=device)

            if i_s > 0:
                # Upsample 2x for this stage (nearest, per frame).
                height *= 2
                width *= 2
                num_frames = latents.shape[2]
                latents = rearrange(latents, "b c t h w -> (b t) c h w")
                latents = F.interpolate(latents, size=(height, width), mode="nearest")
                latents = rearrange(latents, "(b t) c h w -> b c t h w", t=num_frames)
                # Fix the stage
                ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s]  # the original coeff of signal
                gamma = self.scheduler.config.gamma
                alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
                beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)

                batch_size, channel, num_frames, height, width = latents.shape
                noise = self.sample_block_noise(batch_size, channel, num_frames, height, width)
                noise = noise.to(device=device, dtype=transformer_dtype)
                latents = alpha * latents + beta * noise  # To fix the block artifact

                if use_dmd:
                    start_point_list.append(latents)

            if use_dynamic_shifting:
                # Resolution-dependent sigma shift, then map sigmas back onto this
                # stage's timestep range.
                temp_sigmas = apply_schedule_shift(
                    self.scheduler.sigmas,
                    latents,
                    base_seq_len=self.scheduler.config.get("base_image_seq_len", 256),
                    max_seq_len=self.scheduler.config.get("max_image_seq_len", 4096),
                    base_shift=self.scheduler.config.get("base_shift", 0.5),
                    max_shift=self.scheduler.config.get("max_shift", 1.15),
                )
                temp_timesteps = self.scheduler.timesteps_per_stage[i_s].min() + temp_sigmas[:-1] * (
                    self.scheduler.timesteps_per_stage[i_s].max() - self.scheduler.timesteps_per_stage[i_s].min()
                )

                self.scheduler.sigmas = temp_sigmas
                self.scheduler.timesteps = temp_timesteps

            timesteps = self.scheduler.timesteps

            for idx, t in enumerate(timesteps):
                timestep = t.expand(latents.shape[0]).to(torch.int64)

                with self.transformer.cache_context("cond"):
                    noise_pred = self.transformer(
                        hidden_states=latents.to(transformer_dtype),
                        timestep=timestep,
                        encoder_hidden_states=prompt_embeds,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                        indices_hidden_states=indices_hidden_states,
                        indices_latents_history_short=indices_latents_history_short,
                        indices_latents_history_mid=indices_latents_history_mid,
                        indices_latents_history_long=indices_latents_history_long,
                        latents_history_short=latents_history_short.to(transformer_dtype),
                        latents_history_mid=latents_history_mid.to(transformer_dtype),
                        latents_history_long=latents_history_long.to(transformer_dtype),
                    )[0]

                if self.do_classifier_free_guidance:
                    with self.transformer.cache_context("cond_uncond"):
                        noise_uncond = self.transformer(
                            hidden_states=latents.to(transformer_dtype),
                            timestep=timestep,
                            encoder_hidden_states=negative_prompt_embeds,
                            attention_kwargs=attention_kwargs,
                            return_dict=False,
                            indices_hidden_states=indices_hidden_states,
                            indices_latents_history_short=indices_latents_history_short,
                            indices_latents_history_mid=indices_latents_history_mid,
                            indices_latents_history_long=indices_latents_history_long,
                            latents_history_short=latents_history_short.to(transformer_dtype),
                            latents_history_mid=latents_history_mid.to(transformer_dtype),
                            latents_history_long=latents_history_long.to(transformer_dtype),
                        )[0]

                    if use_cfg_zero_star:
                        # CFG-Zero*: see stage1_sample; zero-init applies only to the
                        # first stage's earliest steps here.
                        noise_pred_text = noise_pred
                        positive_flat = noise_pred_text.view(batch_size, -1)
                        negative_flat = noise_uncond.view(batch_size, -1)

                        alpha = optimized_scale(positive_flat, negative_flat)
                        alpha = alpha.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1)))
                        alpha = alpha.to(noise_pred_text.dtype)

                        if (i_s == 0 and idx <= zero_steps) and use_zero_init:
                            noise_pred = noise_pred_text * 0.0
                        else:
                            noise_pred = noise_uncond * alpha + guidance_scale * (
                                noise_pred_text - noise_uncond * alpha
                            )
                    else:
                        noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                if use_dmd:
                    # DMD: x0 prediction, then re-noise from this stage's start point
                    # (not fresh noise, unlike stage1_sample).
                    pred_image_or_video = convert_flow_pred_to_x0(
                        flow_pred=noise_pred,
                        xt=latents,
                        timestep=timestep,
                        sigmas=self.scheduler.sigmas,
                        timesteps=self.scheduler.timesteps,
                    )
                    if idx < len(timesteps) - 1:
                        latents = add_noise(
                            pred_image_or_video,
                            start_point_list[i_s],
                            timesteps[idx + 1] * torch.ones(batch_size, dtype=torch.long, device=noise_pred.device),
                            sigmas=self.scheduler.sigmas,
                            timesteps=self.scheduler.timesteps,
                        )
                    else:
                        latents = pred_image_or_video
                else:
                    if scheduler_type == "unipc":
                        latents = self.scheduler.step_unipc(noise_pred.float(), t, latents, return_dict=False)[0]
                    else:
                        latents = self.scheduler.step(noise_pred.float(), t, latents, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # NOTE(review): same `locals()[k]` fragility as stage1_sample.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

                i += 1

        return latents

    @property
    def guidance_scale(self):
        # Set per-call in `__call__`.
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        # CFG is active only when the guidance scale exceeds 1.
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        # Updated inside the sampling loops; None outside of a call.
        return self._current_timestep

    @property
    def interrupt(self):
        # Checked each denoising step; set externally to skip remaining model work.
        return self._interrupt

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        height: int = 384,
        width: int = 640,
        num_frames: int = 73,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        # NOTE(review): mutable default argument (shared list across calls); the body
        # may reassign it from the callback's `tensor_inputs` -- confirm it is never
        # mutated in place.
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        # ------------ I2V ------------
        image: Optional[PipelineImageInput] = None,
        image_latents: Optional[torch.Tensor] = None,
        fake_image_latents: Optional[torch.Tensor] = None,
        add_noise_to_image_latents: bool = True,
        image_noise_sigma_min: float = 0.111,
        image_noise_sigma_max: float = 0.135,
        # ------------ V2V ------------
        video: Optional[PipelineImageInput] = None,
        video_latents: Optional[torch.Tensor] = None,
        add_noise_to_video_latents:
bool = True, + video_noise_sigma_min: float = 0.111, + video_noise_sigma_max: float = 0.135, + # ------------ Interactive ------------ + use_interpolate_prompt: bool = False, + interpolate_time_list: list = [7, 7, 7], + interpolation_steps: int = 3, + # ------------ Stage 1 ------------ + history_sizes: list = [16, 2, 1], + latent_window_size: int = 9, + use_dynamic_shifting: bool = False, + is_keep_x0: bool = True, + # ------------ Stage 2 ------------ + is_enable_stage2: bool = False, + stage2_num_stages: int = 3, + stage2_num_inference_steps_list: list = [10, 10, 10], + scheduler_type: str = "unipc", # unipc, euler + # ------------ CFG Zero ------------ + use_cfg_zero_star: Optional[bool] = False, + use_zero_init: Optional[bool] = True, + zero_steps: Optional[int] = 1, + # ------------ DMD ------------ + use_dmd: bool = False, + is_skip_first_section: bool = False, + is_amplify_first_chunk: bool = False, + # ------------ other ------------ + vae_decode_type: VAEDecodeType = "default", # "default", "default_batch" + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+ guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HeliosPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~HeliosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HeliosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. 
+ """ + + if image is not None and video is not None: + raise ValueError("image and video cannot be provided simultaneously") + + if use_interpolate_prompt: + assert num_videos_per_prompt == 1, f"num_videos_per_prompt must be 1, got {num_videos_per_prompt}" + assert isinstance(prompt, list), "prompt must be a list" + assert len(prompt) == len(interpolate_time_list), ( + f"Length mismatch: {len(prompt)} vs {len(interpolate_time_list)}" + ) + assert min(interpolate_time_list) > interpolation_steps, ( + f"Minimum value {min(interpolate_time_list)} must be greater than {interpolation_steps}" + ) + interpolate_interval_idx = None + interpolate_embeds = None + interpolate_cumulative_list = list(accumulate(interpolate_time_list)) + + history_sizes = sorted(history_sizes, reverse=True) # From big to small + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(self.vae.device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + self.vae.device, self.vae.dtype + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + vae_dtype = self.vae.dtype + + # 2. Define call parameters + if use_interpolate_prompt or (prompt is not None and isinstance(prompt, str)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. 
Encode input prompt + all_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( + self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + ) + + transformer_dtype = self.transformer.dtype + all_prompt_embeds = all_prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + if use_interpolate_prompt: + negative_prompt_embeds = negative_prompt_embeds[0].unsqueeze(0) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare image + if image is not None: + image = self.video_processor.preprocess(image, height=height, width=width) + image_latents, fake_image_latents = self.prepare_image_latents( + image, + latents_mean=latents_mean, + latents_std=latents_std, + dtype=torch.float32, + device=device, + generator=generator, + latents=image_latents, + fake_latents=fake_image_latents, + ) + + if image_latents is not None and add_noise_to_image_latents: + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + fake_image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min) + + video_noise_sigma_min + ) + fake_image_latents = ( + fake_image_noise_sigma * randn_tensor(fake_image_latents.shape, generator=generator, device=device) + + (1 - fake_image_noise_sigma) * fake_image_latents + ) + + if video is not None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + 
image_latents, video_latents = self.prepare_video_latents( + video, + latents_mean=latents_mean, + latents_std=latents_std, + latent_window_size=latent_window_size, + dtype=torch.float32, + device=device, + generator=generator, + latents=video_latents, + ) + + if video_latents is not None and add_noise_to_video_latents: + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + + noisy_latents_chunks = [] + num_latent_chunks = video_latents.shape[2] // latent_window_size + for i in range(num_latent_chunks): + chunk_start = i * latent_window_size + chunk_end = chunk_start + latent_window_size + latent_chunk = video_latents[:, :, chunk_start:chunk_end, :, :] + + chunk_frames = latent_chunk.shape[2] + frame_sigmas = ( + torch.rand(chunk_frames, device=device, generator=generator) + * (video_noise_sigma_max - video_noise_sigma_min) + + video_noise_sigma_min + ) + frame_sigmas = frame_sigmas.view(1, 1, chunk_frames, 1, 1) + + noisy_chunk = ( + frame_sigmas * randn_tensor(latent_chunk.shape, generator=generator, device=device) + + (1 - frame_sigmas) * latent_chunk + ) + noisy_latents_chunks.append(noisy_chunk) + video_latents = torch.cat(noisy_latents_chunks, dim=2) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + window_num_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 + num_latent_sections = max(1, (num_frames + window_num_frames - 1) // window_num_frames) + history_video = None + total_generated_latent_frames = 0 + + if not is_keep_x0: + history_sizes[-1] = history_sizes[-1] + 1 + history_latents = torch.zeros( + batch_size, + num_channels_latents, + sum(history_sizes), + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + if fake_image_latents is not None: + history_latents = torch.cat([history_latents, fake_image_latents], dim=2) + total_generated_latent_frames += 1 + if video_latents is not None: + history_frames = history_latents.shape[2] + video_frames = video_latents.shape[2] + if video_frames < history_frames: + keep_frames = history_frames - video_frames + history_latents = torch.cat([history_latents[:, :, :keep_frames, :, :], video_latents], dim=2) + else: + history_latents = video_latents + total_generated_latent_frames += video_latents.shape[2] + + # 6. 
Denoising loop + if use_interpolate_prompt: + if num_latent_sections < max(interpolate_cumulative_list): + num_latent_sections = sum(interpolate_cumulative_list) + print(f"Update num_latent_sections to: {num_latent_sections}") + + for k in range(num_latent_sections): + if use_interpolate_prompt: + assert num_latent_sections >= max(interpolate_cumulative_list) + + current_interval_idx = 0 + for idx, cumulative_val in enumerate(interpolate_cumulative_list): + if k < cumulative_val: + current_interval_idx = idx + break + + if current_interval_idx == 0: + prompt_embeds = all_prompt_embeds[0].unsqueeze(0) + else: + interval_start = interpolate_cumulative_list[current_interval_idx - 1] + position_in_interval = k - interval_start + + if position_in_interval < interpolation_steps: + if interpolate_embeds is None or interpolate_interval_idx != current_interval_idx: + interpolate_embeds = self.interpolate_prompt_embeds( + prompt_embeds_1=all_prompt_embeds[current_interval_idx - 1].unsqueeze(0), + prompt_embeds_2=all_prompt_embeds[current_interval_idx].unsqueeze(0), + interpolation_steps=interpolation_steps, + ) + interpolate_interval_idx = current_interval_idx + + prompt_embeds = interpolate_embeds[position_in_interval] + else: + prompt_embeds = all_prompt_embeds[current_interval_idx].unsqueeze(0) + else: + prompt_embeds = all_prompt_embeds + + is_first_section = k == 0 + is_second_section = k == 1 + if is_keep_x0: + if is_first_section: + history_sizes_first_section = [1] + history_sizes.copy() + history_latents_first_section = torch.zeros( + batch_size, + num_channels_latents, + sum(history_sizes_first_section), + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + if fake_image_latents is not None: + history_latents_first_section = torch.cat( + [history_latents_first_section, fake_image_latents], dim=2 + ) + if video_latents is not None: + history_frames = history_latents_first_section.shape[2] + 
video_frames = video_latents.shape[2] + if video_frames < history_frames: + keep_frames = history_frames - video_frames + history_latents_first_section = torch.cat( + [history_latents_first_section[:, :, :keep_frames, :, :], video_latents], dim=2 + ) + else: + history_latents_first_section = video_latents + + indices = torch.arange(0, sum([1, *history_sizes, latent_window_size])) + ( + indices_prefix, + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_1x, + indices_hidden_states, + ) = indices.split([1, *history_sizes, latent_window_size], dim=0) + indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + + latents_prefix, latents_history_long, latents_history_mid, latents_history_1x = ( + history_latents_first_section[:, :, -sum(history_sizes_first_section) :].split( + history_sizes_first_section, dim=2 + ) + ) + if image_latents is not None: + latents_prefix = image_latents + latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + else: + indices = torch.arange(0, sum([1, *history_sizes, latent_window_size])) + ( + indices_prefix, + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_1x, + indices_hidden_states, + ) = indices.split([1, *history_sizes, latent_window_size], dim=0) + indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + + latents_prefix = image_latents + latents_history_long, latents_history_mid, latents_history_1x = history_latents[ + :, :, -sum(history_sizes) : + ].split(history_sizes, dim=2) + latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + else: + indices = torch.arange(0, sum([*history_sizes, latent_window_size])) + ( + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_short, + indices_hidden_states, + ) = indices.split([*history_sizes, latent_window_size], dim=0) + latents_history_long, 
latents_history_mid, latents_history_short = history_latents[ + :, :, -sum(history_sizes) : + ].split(history_sizes, dim=2) + + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height, + width, + window_num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=None, + ) + + if not is_enable_stage2: + self.scheduler.set_timesteps(num_inference_steps, device=device) + + if use_dynamic_shifting: + sigmas = torch.linspace( + 0.999, 0.0, steps=num_inference_steps + 1, dtype=torch.float32, device=device + )[:-1] + sigmas = apply_schedule_shift( + sigmas=sigmas, + noise=latents, + base_seq_len=self.scheduler.config.get("base_image_seq_len", 256), + max_seq_len=self.scheduler.config.get("max_image_seq_len", 4096), + base_shift=self.scheduler.config.get("base_shift", 0.5), + max_shift=self.scheduler.config.get("max_shift", 1.15), + ) + timesteps = sigmas * 1000.0 # rescale to [0, 1000.0) + timesteps = timesteps.to(device) + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + self.scheduler.timesteps = timesteps + self.scheduler.sigmas = sigmas + + timesteps = self.scheduler.timesteps + + dmd_sigmas = None + dmd_timesteps = None + if use_dmd: + dmd_sigmas = self.scheduler.sigmas.to(self.transformer.device) + dmd_timesteps = self.scheduler.timesteps.to(self.transformer.device) + + self._num_timesteps = len(timesteps) + else: + num_inference_steps = ( + sum(stage2_num_inference_steps_list) * 2 + if is_amplify_first_chunk and use_dmd and is_first_section + else sum(stage2_num_inference_steps_list) + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + if is_enable_stage2: + latents = self.stage2_sample( + latents=latents, + stage2_num_stages=stage2_num_stages, + stage2_num_inference_steps_list=stage2_num_inference_steps_list, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + indices_hidden_states=indices_hidden_states, + 
indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + attention_kwargs=attention_kwargs, + device=device, + transformer_dtype=transformer_dtype, + scheduler_type=scheduler_type, + use_dynamic_shifting=use_dynamic_shifting, + # ------------ CFG Zero ------------ + use_cfg_zero_star=use_cfg_zero_star, + use_zero_init=use_zero_init, + zero_steps=zero_steps, + # -------------- DMD -------------- + use_dmd=use_dmd, + is_amplify_first_chunk=is_amplify_first_chunk and is_first_section, + # ------------ Callback ------------ + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + progress_bar=progress_bar, + ) + else: + latents = self.stage1_sample( + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + timesteps=timesteps, + guidance_scale=guidance_scale, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + attention_kwargs=attention_kwargs, + device=device, + transformer_dtype=transformer_dtype, + scheduler_type=scheduler_type, + use_dynamic_shifting=use_dynamic_shifting, + generator=generator, + # ------------ CFG Zero ------------ + use_cfg_zero_star=use_cfg_zero_star, + use_zero_init=use_zero_init, + zero_steps=zero_steps, + # -------------- DMD -------------- + use_dmd=use_dmd, + dmd_sigmas=dmd_sigmas, + dmd_timesteps=dmd_timesteps, + is_amplify_first_chunk=is_amplify_first_chunk and 
is_first_section, + # ------------ Callback ------------ + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + progress_bar=progress_bar, + ) + + if is_keep_x0 and ( + (is_first_section and image_latents is None) or (is_skip_first_section and is_second_section) + ): + image_latents = latents[:, :, 0:1, :, :] + + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([history_latents, latents], dim=2) + real_history_latents = history_latents[:, :, -total_generated_latent_frames:] + index_slice = ( + slice(None), + slice(None), + slice(-latent_window_size, None), + ) + + if vae_decode_type == "default": + current_latents = real_history_latents[index_slice].to(vae_dtype) / latents_std + latents_mean + current_video = self.vae.decode(current_latents, return_dict=False)[0] + + if history_video is None: + history_video = current_video + else: + history_video = torch.cat([history_video, current_video], dim=2) + + self._current_timestep = None + + if output_type != "latent": + if vae_decode_type == "default_batch": + total_latent_frames = real_history_latents.shape[2] + batch_size = real_history_latents.shape[0] + num_chunks = total_latent_frames // latent_window_size + + chunks = ( + real_history_latents.reshape( + batch_size, + -1, + num_chunks, + latent_window_size, + real_history_latents.shape[-2], + real_history_latents.shape[-1], + ) + .permute(0, 2, 1, 3, 4, 5) + .reshape( + batch_size * num_chunks, + -1, + latent_window_size, + real_history_latents.shape[-2], + real_history_latents.shape[-1], + ) + ) + + chunks = chunks.to(vae_dtype) / latents_std + latents_mean + batch_video = self.vae.decode(chunks, return_dict=False)[0] + + video_frames_per_chunk = batch_video.shape[2] + history_video = ( + batch_video.reshape( + batch_size, + num_chunks, + -1, + video_frames_per_chunk, + batch_video.shape[-2], + batch_video.shape[-1], + ) + .permute(0, 2, 1, 3, 4, 5) + .reshape( + 
from dataclasses import dataclass

import torch

from diffusers.utils import BaseOutput


@dataclass
class HeliosPipelineOutput(BaseOutput):
    r"""
    Output class for Helios pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of
            shape `(batch_size, num_frames, channels, height, width)`.
    """

    # Decoded video frames (or latents when `output_type="latent"`); see the pipeline's `__call__`.
    frames: torch.Tensor
@dataclass
class HeliosUniPCSchedulerOutput(BaseOutput):
    """
    Output of [`HeliosUniPCScheduler.step`] and [`HeliosUniPCScheduler.step_unipc`].

    `model_outputs`, `last_sample` and `this_order` carry UniPC multistep bookkeeping and are only populated by
    the UniPC path. They default to `None` so the plain Euler `step` path can build this output from
    `prev_sample` alone — previously `HeliosUniPCSchedulerOutput(prev_sample=...)` raised a `TypeError`
    because these extra dataclass fields had no defaults.
    """

    prev_sample: torch.FloatTensor
    model_outputs: Optional[torch.FloatTensor] = None
    last_sample: Optional[torch.FloatTensor] = None
    this_order: Optional[int] = None


class HeliosUniPCScheduler(SchedulerMixin, ConfigMixin):
    """
    Multi-stage flow-matching scheduler with UniPC predictor/corrector updates.

    The sigma range is split into `stages` sub-intervals (via `stage_range`); per-stage timestep/sigma grids are
    precomputed here and selected later through `set_timesteps(num_inference_steps, stage_index)`.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,  # Following Stable Diffusion 3
        stages: int = 3,
        # NOTE(review): mutable default; value is captured by `register_to_config` and never mutated here.
        stage_range: List = [0, 1 / 3, 2 / 3, 1],
        gamma: float = 1 / 3,
        # For UniPC
        thresholding: bool = False,
        prediction_type: str = "flow_prediction",
        solver_order: int = 2,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        lower_order_final: bool = True,
        # NOTE(review): mutable default; captured by `register_to_config` — confirm it is never mutated in place.
        disable_corrector: List[int] = [],
        solver_p: SchedulerMixin = None,
        use_flow_sigmas: bool = True,
        version: str = "v1",
    ):
        self.version = version
        self.timestep_ratios = {}  # The timestep ratio for each stage (v1 only)
        self.timesteps_per_stage = {}  # The detailed timesteps per stage (fix max and min per stage)
        self.sigmas_per_stage = {}  # always uniform [1000, 0]
        self.start_sigmas = {}  # for start point / upsample renoise (after gamma correction)
        self.end_sigmas = {}  # for end point
        self.ori_start_sigmas = {}  # start sigmas before gamma correction

        self.init_sigmas_for_each_stage()
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self.gamma = gamma

        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                # Legacy solver names are silently mapped onto "bh2".
                self.register_to_config(solver_type="bh2")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        self.predict_x0 = predict_x0
        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        self.disable_corrector = disable_corrector
        self.solver_p = solver_p
        self.last_sample = None
        self._step_index = None
        self._begin_index = None

    def init_sigmas(self):
        """
        Initialize the global (stage-independent) timesteps and sigmas.
        """
        num_train_timesteps = self.config.num_train_timesteps
        shift = self.config.shift

        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps + 1)
        sigmas = 1.0 - alphas
        # Apply the flow-matching shift, then flip so sigmas run high -> low; drop the trailing 0.
        sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy()
        sigmas = torch.from_numpy(sigmas)
        timesteps = (sigmas * num_train_timesteps).clone()

        self._step_index = None
        self._begin_index = None
        self.timesteps = timesteps
        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    def init_sigmas_for_each_stage(self):
        """
        Init the timesteps for each stage from `stage_range`.
        """
        self.init_sigmas()

        stage_distance = []
        stages = self.config.stages
        training_steps = self.config.num_train_timesteps
        stage_range = self.config.stage_range

        # Init the start and end point of each stage
        for i_s in range(stages):
            start_indice = max(int(stage_range[i_s] * training_steps), 0)
            end_indice = min(int(stage_range[i_s + 1] * training_steps), training_steps)
            start_sigma = self.sigmas[start_indice].item()
            end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
            self.ori_start_sigmas[i_s] = start_sigma

            if i_s != 0:
                # Gamma-based correction of the stage entry sigma (renoising at stage boundaries).
                ori_sigma = 1 - start_sigma
                gamma = self.config.gamma
                corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
                start_sigma = 1 - corrected_sigma

            stage_distance.append(start_sigma - end_sigma)
            self.start_sigmas[i_s] = start_sigma
            self.end_sigmas[i_s] = end_sigma

            if self.version == "v2":
                # v2 slices the global schedule directly at the corrected start sigma.
                new_start_indice = (
                    len(self.sigmas) - torch.searchsorted(self.sigmas.flip(0), start_sigma, right=True)
                ).item()
                self.sigmas_per_stage[i_s] = self.sigmas[new_start_indice:end_indice]
                self.timesteps_per_stage[i_s] = self.timesteps[new_start_indice:end_indice]

        if self.version == "v2":
            return

        # v1: determine the ratio of each stage according to flow length.
        tot_distance = sum(stage_distance)
        for i_s in range(stages):
            start_ratio = 0.0 if i_s == 0 else sum(stage_distance[:i_s]) / tot_distance
            end_ratio = 0.9999999999999999 if i_s == stages - 1 else sum(stage_distance[: i_s + 1]) / tot_distance
            self.timestep_ratios[i_s] = (start_ratio, end_ratio)

        # Determine the timesteps and sigmas for each stage.
        for i_s in range(stages):
            timestep_ratio = self.timestep_ratios[i_s]
            timestep_max = min(self.timesteps[int(timestep_ratio[0] * training_steps)], 999)
            timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
            timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1)
            self.timesteps_per_stage[i_s] = (
                timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
            )
            # Per-stage sigmas are always a uniform [0.999, 0) grid.
            stage_sigmas = np.linspace(0.999, 0, training_steps + 1)
            self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index
+ """ + self._begin_index = begin_index + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps( + self, + num_inference_steps: int, + stage_index: int, + device: Union[str, torch.device] = None, + ): + """ + Setting the timesteps and sigmas for each stage + """ + self.num_inference_steps = num_inference_steps + self.init_sigmas() + + if self.version == "v1": + stage_timesteps = self.timesteps_per_stage[stage_index] + timestep_max = stage_timesteps[0].item() + timestep_min = stage_timesteps[-1].item() + + timesteps = np.linspace( + timestep_max, + timestep_min, + num_inference_steps, + ) + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + stage_sigmas = self.sigmas_per_stage[stage_index] + sigma_max = stage_sigmas[0].item() + sigma_min = stage_sigmas[-1].item() + + ratios = np.linspace(sigma_max, sigma_min, num_inference_steps) + sigmas = torch.from_numpy(ratios).to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + else: + total_steps = len(self.timesteps_per_stage[stage_index]) + indices = np.linspace(0, total_steps - 1, num_inference_steps, dtype=int) + + self.timesteps = self.timesteps_per_stage[stage_index][indices].to(device=device) + + if stage_index == (self.config.stages - 1): + sigmas = self.sigmas_per_stage[stage_index][indices].to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + else: + sigmas = self.sigmas_per_stage[stage_index][indices].to(device=device) + self.sigmas = torch.cat( + [sigmas, torch.tensor([self.ori_start_sigmas[stage_index + 1]], device=sigmas.device)] + ) + + self._step_index = None + self.reset_scheduler_history() + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always 
the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor] = None, + sample: torch.FloatTensor = None, + generator: Optional[torch.Generator] = None, + sigma: Optional[torch.FloatTensor] = None, + sigma_next: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[HeliosUniPCSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. 
+ """ + + assert (sigma is None) == (sigma_next is None), "sigma and sigma_next must both be None or both be not None" + + if sigma is None and sigma_next is None: + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._step_index = 0 + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + if sigma is None and sigma_next is None: + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + + prev_sample = sample + (sigma_next - sigma) * model_output + + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return HeliosUniPCSchedulerOutput(prev_sample=prev_sample) + + # ---------------------------------- UniPC ---------------------------------- + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + if self.config.use_flow_sigmas: + alpha_t = 1 - sigma + sigma_t = torch.clamp(sigma, min=1e-8) + else: + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + sigma: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. 
+ timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + flag = False + if sigma is None: + flag = True + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "epsilon": + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + x0_pred = alpha_t * sample - sigma_t * model_output + elif self.config.prediction_type == "flow_prediction": + if flag: + sigma_t = self.sigmas[self.step_index] + else: + sigma_t = sigma + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler." 
    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        order: int = None,
        sigma: torch.Tensor = None,
        sigma_next: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model at the current timestep.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`):
                The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
            sigma (`torch.Tensor`, *optional*), sigma_next (`torch.Tensor`, *optional*):
                Explicit sigma bounds; when omitted, read from `self.sigmas` around the step index.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        # Legacy positional `prev_timestep` support (deprecated, no effect).
        prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError("missing `order` as a required keyword argument")
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        model_output_list = self.model_outputs

        s0 = self.timestep_list[-1]
        m0 = model_output_list[-1]
        x = sample

        if self.solver_p:
            # Delegate the predictor step to the auxiliary solver when one is configured.
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        if sigma_next is None and sigma is None:
            sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        else:
            sigma_t, sigma_s0 = sigma_next, sigma
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        # Work in log-SNR (lambda) space, as in the UniPC paper.
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = sample.device

        # rks: normalized lambda offsets of previous steps; D1s: scaled model-output differences.
        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - i
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        # Build the Vandermonde-style system R * rhos = b (see UniPC paper, B(h) variant).
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
        else:
            D1s = None

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - sigma_t * B_h * pred_res

        x_t = x_t.to(x.dtype)
        return x_t
+ """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError("missing `last_sample` as a required keyword argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError("missing `this_sample` as a required keyword argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError("missing `order` as a required keyword argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + if sigma_before is None and sigma is None: + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] + else: + sigma_t, sigma_s0 = sigma, sigma_before + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == 
"bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def step_unipc( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor] = None, + sample: torch.Tensor = None, + return_dict: bool = True, + model_outputs: list = None, + timestep_list: list = None, + sigma_before: torch.Tensor = None, + sigma: torch.Tensor = None, + sigma_next: torch.Tensor = None, + cus_step_index: int = None, + cus_lower_order_num: int = None, + cus_this_order: int = None, + cus_last_sample: torch.Tensor = None, + ) -> Union[HeliosUniPCSchedulerOutput, Tuple]: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if cus_step_index is None: + if self.step_index is None: + self._step_index = 0 + else: + self._step_index = cus_step_index + + if cus_lower_order_num is not None: + 
self.lower_order_nums = cus_lower_order_num + + if cus_this_order is not None: + self.this_order = cus_this_order + + if cus_last_sample is not None: + self.last_sample = cus_last_sample + + use_corrector = ( + self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None + ) + + # Convert model output using the proper conversion method + model_output_convert = self.convert_model_output(model_output, sample=sample, sigma=sigma) + + if model_outputs is not None and timestep_list is not None: + self.model_outputs = model_outputs[:-1] + self.timestep_list = timestep_list[:-1] + + # print("1", self.step_index, self.timestep_list) + + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + sigma_before=sigma_before, + sigma=sigma, + ) + + if model_outputs is not None and timestep_list is not None: + model_outputs[-1] = model_output_convert + self.model_outputs = model_outputs[1:] + self.timestep_list = timestep_list[1:] + else: + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) + else: + this_order = self.config.solver_order + self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + sigma=sigma, + sigma_next=sigma_next, + ) + + if cus_lower_order_num is None: + if self.lower_order_nums < self.config.solver_order: + 
self.lower_order_nums += 1 + + # upon completion increase step index by one + if cus_step_index is None: + self._step_index += 1 + + if not return_dict: + return (prev_sample, model_outputs, self.last_sample, self.this_order) + + return HeliosUniPCSchedulerOutput( + prev_sample=prev_sample, + model_outputs=model_outputs, + last_sample=self.last_sample, + this_order=self.this_order, + ) + + def reset_scheduler_history(self): + self.model_outputs = [None] * self.config.solver_order + self.timestep_list = [None] * self.config.solver_order + self.lower_order_nums = 0 + self.disable_corrector = self.config.disable_corrector + self.solver_p = self.config.solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 4e402921aa5f..6ce4e4438a47 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1151,6 +1151,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HeliosTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class I2VGenXLUNet(metaclass=DummyObject): _backends = ["torch"] @@ -2488,6 +2503,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HeliosUniPCScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, 
["torch"]) + + class DDIMInverseScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 8758c549ca77..51ee5312d2e7 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1487,6 +1487,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HeliosPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HunyuanVideoFramepackPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_helios.py b/tests/models/transformers/test_models_transformer_helios.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/helios/__init__.py b/tests/pipelines/helios/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/helios/test_helios.py b/tests/pipelines/helios/test_helios.py new file mode 100644 index 000000000000..86e1ecbf70be --- /dev/null +++ b/tests/pipelines/helios/test_helios.py @@ -0,0 +1,13 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From 58f4603a6fc105af62cc03fa89e6b5f5fc49aa59 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Tue, 24 Feb 2026 14:55:45 +0000 Subject: [PATCH 002/107] fix test --- src/diffusers/schedulers/scheduling_helios_unipc.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_helios_unipc.py b/src/diffusers/schedulers/scheduling_helios_unipc.py index 1ef41cdd6c0d..28dc5aa76f26 100644 --- a/src/diffusers/schedulers/scheduling_helios_unipc.py +++ b/src/diffusers/schedulers/scheduling_helios_unipc.py @@ -174,7 +174,6 @@ def begin_index(self): """ return self._begin_index - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index def set_begin_index(self, begin_index: int = 0): """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. 
@@ -332,7 +331,6 @@ def step( return HeliosUniPCSchedulerOutput(prev_sample=prev_sample) # ---------------------------------- UniPC ---------------------------------- - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): if self.config.use_flow_sigmas: alpha_t = 1 - sigma From 280e3d87a683e694b09ddaf874aadf0b96d63fa9 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Tue, 24 Feb 2026 15:01:30 +0000 Subject: [PATCH 003/107] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 48 +++++++++---------- .../dummy_torch_and_transformers_objects.py | 20 ++++---- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6ce4e4438a47..55ed0c212d7a 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1031,7 +1031,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HiDreamImageTransformer2DModel(metaclass=DummyObject): +class HeliosTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -1046,7 +1046,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HunyuanDiT2DControlNetModel(metaclass=DummyObject): +class HiDreamImageTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -1061,7 +1061,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HunyuanDiT2DModel(metaclass=DummyObject): +class HunyuanDiT2DControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -1076,7 +1076,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject): +class HunyuanDiT2DModel(metaclass=DummyObject): _backends = 
["torch"] def __init__(self, *args, **kwargs): @@ -1091,7 +1091,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HunyuanImageTransformer2DModel(metaclass=DummyObject): +class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -1106,7 +1106,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HunyuanVideo15Transformer3DModel(metaclass=DummyObject): +class HunyuanImageTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -1121,7 +1121,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject): +class HunyuanVideo15Transformer3DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -1136,7 +1136,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HunyuanVideoTransformer3DModel(metaclass=DummyObject): +class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -1151,7 +1151,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HeliosTransformer3DModel(metaclass=DummyObject): +class HunyuanVideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -2503,21 +2503,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HeliosUniPCScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class DDIMInverseScheduler(metaclass=DummyObject): _backends = ["torch"] 
@@ -2773,6 +2758,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HeliosUniPCScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HeunDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 51ee5312d2e7..a5830c9f0fd7 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1352,7 +1352,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class HiDreamImagePipeline(metaclass=DummyObject): +class HeliosPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1367,7 +1367,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class HunyuanDiTControlNetPipeline(metaclass=DummyObject): +class HiDreamImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1382,7 +1382,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class HunyuanDiTPAGPipeline(metaclass=DummyObject): +class HunyuanDiTControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1397,7 +1397,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class HunyuanDiTPipeline(metaclass=DummyObject): +class HunyuanDiTPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, 
*args, **kwargs): @@ -1412,7 +1412,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class HunyuanImagePipeline(metaclass=DummyObject): +class HunyuanDiTPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1427,7 +1427,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class HunyuanImageRefinerPipeline(metaclass=DummyObject): +class HunyuanImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1442,7 +1442,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject): +class HunyuanImageRefinerPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1457,7 +1457,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class HunyuanVideo15ImageToVideoPipeline(metaclass=DummyObject): +class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1472,7 +1472,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class HunyuanVideo15Pipeline(metaclass=DummyObject): +class HunyuanVideo15ImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1487,7 +1487,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class HeliosPipeline(metaclass=DummyObject): +class HunyuanVideo15Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): From 6e0f23b273bb07353ec09d87b21f8a5188adc159 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Tue, 24 Feb 2026 15:09:26 +0000 Subject: [PATCH 
004/107] change script path --- 0_temp_helios_test/infer_helios.py | 4 ++-- 0_temp_helios_test/stage-1_i2v.sh | 4 ++-- 0_temp_helios_test/stage-1_t2v.sh | 4 ++-- 0_temp_helios_test/stage-1_v2v.sh | 4 ++-- 0_temp_helios_test/stage-2_i2v.sh | 4 ++-- 0_temp_helios_test/stage-2_t2v.sh | 4 ++-- 0_temp_helios_test/stage-2_v2v.sh | 4 ++-- 0_temp_helios_test/stage-3_i2v.sh | 4 ++-- 0_temp_helios_test/stage-3_t2v.sh | 4 ++-- 0_temp_helios_test/stage-3_v2v.sh | 4 ++-- 10 files changed, 20 insertions(+), 20 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index 43fbabf47392..02e01e08281a 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -25,11 +25,11 @@ def parse_args(): parser = argparse.ArgumentParser(description="Generate video with model") # === Model paths === - parser.add_argument("--base_model_path", type=str, default="/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base") + parser.add_argument("--base_model_path", type=str, default="BestWishYsh/Helios-Base") parser.add_argument( "--transformer_path", type=str, - default="/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base", + default="BestWishYsh/Helios-Base", ) parser.add_argument( "--lora_path", diff --git a/0_temp_helios_test/stage-1_i2v.sh b/0_temp_helios_test/stage-1_i2v.sh index 36da9461fea9..e601991b09df 100644 --- a/0_temp_helios_test/stage-1_i2v.sh +++ b/0_temp_helios_test/stage-1_i2v.sh @@ -1,6 +1,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base" \ - --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base" \ + --base_model_path "BestWishYsh/Helios-Base" \ + --transformer_path "BestWishYsh/Helios-Base" \ --sample_type "i2v" \ --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" \ 
--prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." \ diff --git a/0_temp_helios_test/stage-1_t2v.sh b/0_temp_helios_test/stage-1_t2v.sh index 8b2fea961533..6520b8a0713a 100644 --- a/0_temp_helios_test/stage-1_t2v.sh +++ b/0_temp_helios_test/stage-1_t2v.sh @@ -1,6 +1,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base" \ - --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base" \ + --base_model_path "BestWishYsh/Helios-Base" \ + --transformer_path "BestWishYsh/Helios-Base" \ --sample_type "t2v" \ --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." 
\ --use_dynamic_shifting \ diff --git a/0_temp_helios_test/stage-1_v2v.sh b/0_temp_helios_test/stage-1_v2v.sh index 349c42d110da..e18920f456ad 100644 --- a/0_temp_helios_test/stage-1_v2v.sh +++ b/0_temp_helios_test/stage-1_v2v.sh @@ -1,6 +1,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base" \ - --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Base" \ + --base_model_path "BestWishYsh/Helios-Base" \ + --transformer_path "BestWishYsh/Helios-Base" \ --sample_type "v2v" \ --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ diff --git a/0_temp_helios_test/stage-2_i2v.sh b/0_temp_helios_test/stage-2_i2v.sh index 56b3ad46d1a2..e3cd112c2880 100644 --- a/0_temp_helios_test/stage-2_i2v.sh +++ b/0_temp_helios_test/stage-2_i2v.sh @@ -1,6 +1,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Mid" \ - --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Mid" \ + --base_model_path "BestWishYsh/Helios-Mid" \ + --transformer_path "BestWishYsh/Helios-Mid" \ --sample_type "i2v" \ --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" \ --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." 
\ diff --git a/0_temp_helios_test/stage-2_t2v.sh b/0_temp_helios_test/stage-2_t2v.sh index 078b7ebaf101..780fc3d18227 100644 --- a/0_temp_helios_test/stage-2_t2v.sh +++ b/0_temp_helios_test/stage-2_t2v.sh @@ -1,6 +1,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Mid" \ - --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Mid" \ + --base_model_path "BestWishYsh/Helios-Mid" \ + --transformer_path "BestWishYsh/Helios-Mid" \ --sample_type "t2v" \ --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." 
\ --is_enable_stage2 \ diff --git a/0_temp_helios_test/stage-2_v2v.sh b/0_temp_helios_test/stage-2_v2v.sh index 9158e9a7f68b..62455ca1623e 100644 --- a/0_temp_helios_test/stage-2_v2v.sh +++ b/0_temp_helios_test/stage-2_v2v.sh @@ -1,6 +1,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Mid" \ - --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Mid" \ + --base_model_path "BestWishYsh/Helios-Mid" \ + --transformer_path "BestWishYsh/Helios-Mid" \ --sample_type "v2v" \ --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ diff --git a/0_temp_helios_test/stage-3_i2v.sh b/0_temp_helios_test/stage-3_i2v.sh index 340c21ab639d..78416244b983 100644 --- a/0_temp_helios_test/stage-3_i2v.sh +++ b/0_temp_helios_test/stage-3_i2v.sh @@ -1,6 +1,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Distilled" \ - --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Distilled" \ + --base_model_path "BestWishYsh/Helios-Distilled" \ + --transformer_path "BestWishYsh/Helios-Distilled" \ --sample_type "i2v" \ --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" \ --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." 
\ diff --git a/0_temp_helios_test/stage-3_t2v.sh b/0_temp_helios_test/stage-3_t2v.sh index 5240f3cbeff6..634217b50990 100644 --- a/0_temp_helios_test/stage-3_t2v.sh +++ b/0_temp_helios_test/stage-3_t2v.sh @@ -1,6 +1,6 @@ CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ - --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Distilled" \ - --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Distilled" \ + --base_model_path "BestWishYsh/Helios-Distilled" \ + --transformer_path "BestWishYsh/Helios-Distilled" \ --sample_type "t2v" \ --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." 
\ --num_frames 240 \ diff --git a/0_temp_helios_test/stage-3_v2v.sh b/0_temp_helios_test/stage-3_v2v.sh index 1be947dd0f00..27affdf31d6c 100644 --- a/0_temp_helios_test/stage-3_v2v.sh +++ b/0_temp_helios_test/stage-3_v2v.sh @@ -1,6 +1,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Distilled" \ - --transformer_path "/mnt/bn/yufan-dev-my/ysh_new/Codes/Helios/1_formal_ckpts/Helios-Distilled" \ + --base_model_path "BestWishYsh/Helios-Distilled" \ + --transformer_path "BestWishYsh/Helios-Distilled" \ --sample_type "v2v" \ --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ From 3a04f67a401ad6c6c78ed15d2ab9c9513f0a3c99 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Tue, 24 Feb 2026 15:14:20 +0000 Subject: [PATCH 005/107] fix cus script --- 0_temp_helios_test/infer_helios.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index 02e01e08281a..dbd7dbf1ccdf 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -62,7 +62,7 @@ def parse_args(): # base parser.add_argument("--height", type=int, default=384) parser.add_argument("--width", type=int, default=640) - parser.add_argument("--num_frames", type=int, default=73) + parser.add_argument("--num_frames", type=int, default=99) parser.add_argument("--num_inference_steps", type=int, default=50) parser.add_argument("--guidance_scale", type=float, default=5.0) parser.add_argument("--use_dynamic_shifting", action="store_true") @@ -238,7 +238,10 @@ def main(): pipe.vae.enable_tiling() if args.enable_compile: - pipe.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=True) + torch.backends.cudnn.benchmark = True + pipe.text_encoder.compile(mode="max-autotune-no-cudagraphs", dynamic=False) + pipe.vae.compile(mode="max-autotune-no-cudagraphs", dynamic=False) + pipe.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=False) if args.low_vram_mode: pipe.enable_group_offload( @@ -254,7 +257,7 @@ def main(): pipe = pipe.to(device) if world_size > 1 and args.enable_parallelism: - transformer.set_attention_backend("flash") + # transformer.set_attention_backend("flash") pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=world_size)) if args.debug_mode: From d693c046595acc0ba566f865ffcdf0f2f80c2076 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Wed, 25 Feb 2026 05:49:59 +0000 Subject: [PATCH 006/107] update docs --- 0_temp_helios_test/stage-1_i2v.sh | 1 + 0_temp_helios_test/stage-1_v2v.sh | 1 + 
0_temp_helios_test/stage-2_i2v.sh | 3 + 0_temp_helios_test/stage-2_t2v.sh | 2 + 0_temp_helios_test/stage-2_v2v.sh | 3 + 0_temp_helios_test/stage-3_i2v.sh | 3 +- 0_temp_helios_test/stage-3_t2v.sh | 3 +- 0_temp_helios_test/stage-3_v2v.sh | 3 +- docs/source/en/using-diffusers/consisid.md | 2 +- docs/source/en/using-diffusers/helios.md | 67 ++++++++++++++++++++++ docs/source/zh/using-diffusers/helios.md | 67 ++++++++++++++++++++++ 11 files changed, 151 insertions(+), 4 deletions(-) diff --git a/0_temp_helios_test/stage-1_i2v.sh b/0_temp_helios_test/stage-1_i2v.sh index e601991b09df..d660198b1c9b 100644 --- a/0_temp_helios_test/stage-1_i2v.sh +++ b/0_temp_helios_test/stage-1_i2v.sh @@ -7,6 +7,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --use_dynamic_shifting \ --output_folder "./output_helios/stage-1" + # --use_default_loader \ # --enable_compile \ # --use_cfg_zero_star \ diff --git a/0_temp_helios_test/stage-1_v2v.sh b/0_temp_helios_test/stage-1_v2v.sh index e18920f456ad..20177851204d 100644 --- a/0_temp_helios_test/stage-1_v2v.sh +++ b/0_temp_helios_test/stage-1_v2v.sh @@ -7,6 +7,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --use_dynamic_shifting \ --output_folder "./output_helios/stage-1" + # --use_default_loader \ # --enable_compile \ # --use_cfg_zero_star \ diff --git a/0_temp_helios_test/stage-2_i2v.sh b/0_temp_helios_test/stage-2_i2v.sh index e3cd112c2880..7ba8a9520ea7 100644 --- a/0_temp_helios_test/stage-2_i2v.sh +++ b/0_temp_helios_test/stage-2_i2v.sh @@ -5,11 +5,14 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" \ --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. 
A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." \ --is_enable_stage2 \ + --stage2_num_inference_steps_list 20 20 20 \ --use_dynamic_shifting \ --use_cfg_zero_star \ --use_zero_init \ --zero_steps 1 \ --output_folder "./output_helios/stage-2" + + # --stage2_num_inference_steps_list 17 17 17 \ # --use_default_loader \ # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-2_t2v.sh b/0_temp_helios_test/stage-2_t2v.sh index 780fc3d18227..c6fd3176e05d 100644 --- a/0_temp_helios_test/stage-2_t2v.sh +++ b/0_temp_helios_test/stage-2_t2v.sh @@ -4,6 +4,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --sample_type "t2v" \ --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." 
\ --is_enable_stage2 \ + --stage2_num_inference_steps_list 20 20 20 \ --use_dynamic_shifting \ --use_cfg_zero_star \ --use_zero_init \ @@ -11,5 +12,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --output_folder "./output_helios/stage-2" + # --stage2_num_inference_steps_list 17 17 17 \ # --use_default_loader \ # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-2_v2v.sh b/0_temp_helios_test/stage-2_v2v.sh index 62455ca1623e..167edeca6cba 100644 --- a/0_temp_helios_test/stage-2_v2v.sh +++ b/0_temp_helios_test/stage-2_v2v.sh @@ -5,11 +5,14 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ --is_enable_stage2 \ + --stage2_num_inference_steps_list 20 20 20 \ --use_dynamic_shifting \ --use_cfg_zero_star \ --use_zero_init \ --zero_steps 1 \ --output_folder "./output_helios/stage-2" + + # --stage2_num_inference_steps_list 17 17 17 \ # --use_default_loader \ # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-3_i2v.sh b/0_temp_helios_test/stage-3_i2v.sh index 78416244b983..4735788b1176 100644 --- a/0_temp_helios_test/stage-3_i2v.sh +++ b/0_temp_helios_test/stage-3_i2v.sh @@ -13,5 +13,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --is_amplify_first_chunk \ --output_folder "./output_helios/stage-3" - # --use_default_loader \ + + # --stage2_num_inference_steps_list 1 1 1 \ # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-3_t2v.sh b/0_temp_helios_test/stage-3_t2v.sh index 634217b50990..29a74c371e0e 100644 --- a/0_temp_helios_test/stage-3_t2v.sh +++ b/0_temp_helios_test/stage-3_t2v.sh @@ -12,5 +12,6 @@ CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ --is_amplify_first_chunk \ --output_folder "./output_helios/stage-3" - # --use_default_loader \ + + # --stage2_num_inference_steps_list 1 1 1 \ # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-3_v2v.sh b/0_temp_helios_test/stage-3_v2v.sh index 27affdf31d6c..088004a15a63 100644 --- a/0_temp_helios_test/stage-3_v2v.sh +++ b/0_temp_helios_test/stage-3_v2v.sh @@ -13,5 +13,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --is_amplify_first_chunk \ --output_folder "./output_helios/stage-3" - # --use_default_loader \ + + # --stage2_num_inference_steps_list 1 1 1 \ # --enable_compile \ \ No newline at end of file diff --git a/docs/source/en/using-diffusers/consisid.md b/docs/source/en/using-diffusers/consisid.md index b6b04ddaf57e..96ece5b20c3a 100644 --- a/docs/source/en/using-diffusers/consisid.md +++ b/docs/source/en/using-diffusers/consisid.md @@ -60,7 +60,7 @@ export_to_video(video.frames[0], 
"output.mp4", fps=8) Face Image Video - DescriptionDescription diff --git a/docs/source/en/using-diffusers/helios.md b/docs/source/en/using-diffusers/helios.md index 93b276c29d2d..6a65ba24141d 100644 --- a/docs/source/en/using-diffusers/helios.md +++ b/docs/source/en/using-diffusers/helios.md @@ -48,16 +48,83 @@ pipe.to("cuda") ## Text-to-Video Showcases + + + + + + + + + + + +
PromptGenerated Video
A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression. + + +
A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle. + + +
## Image-to-Video Showcases + + + + + + + + + + + + + + +
ImagePromptGenerated Video
A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees. + + +
A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective. + + +
## Interactive-Video Showcases + + + + + + + + + + + +
PromptGenerated Video
The prompt can be found here + +
The prompt can be found here + +
## Resources diff --git a/docs/source/zh/using-diffusers/helios.md b/docs/source/zh/using-diffusers/helios.md index 17447ed809eb..9e1ac69bb855 100644 --- a/docs/source/zh/using-diffusers/helios.md +++ b/docs/source/zh/using-diffusers/helios.md @@ -48,16 +48,83 @@ pipe.to("cuda") ## Text-to-Video Showcases + + + + + + + + + + + +
PromptGenerated Video
A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression. + + +
A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle. + + +
## Image-to-Video Showcases + + + + + + + + + + + + + + +
ImagePromptGenerated Video
A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees. + + +
A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective. + + +
## Interactive-Video Showcases + + + + + + + + + + + +
PromptGenerated Video
The prompt can be found here + +
The prompt can be found here + +
## Resources From 44b04ddf9ef719b962ddeed76949098f25575c2f Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Wed, 25 Feb 2026 09:36:04 +0000 Subject: [PATCH 007/107] fix documented check --- docs/source/en/api/loaders/lora.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index bbae6a9020af..db1ea884558f 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -23,6 +23,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi - [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow). - [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video). - [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana). +- [`HeliosLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/helios). - [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video). - [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2). - [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan). 
@@ -86,6 +87,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi [[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin +## HeliosLoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.HeliosLoraLoaderMixin + ## HunyuanVideoLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin From 50ab6a540509d1a8682543b59d723a70466657d4 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Wed, 25 Feb 2026 09:53:04 +0000 Subject: [PATCH 008/107] update links for docs and examples --- 0_temp_helios_test/stage-1_i2v.sh | 2 +- 0_temp_helios_test/stage-1_v2v.sh | 2 +- 0_temp_helios_test/stage-2_i2v.sh | 2 +- 0_temp_helios_test/stage-2_v2v.sh | 2 +- 0_temp_helios_test/stage-3_i2v.sh | 2 +- 0_temp_helios_test/stage-3_v2v.sh | 2 +- docs/source/en/api/pipelines/helios.md | 12 ++++++------ docs/source/en/using-diffusers/helios.md | 20 ++++++++++---------- docs/source/zh/using-diffusers/helios.md | 20 ++++++++++---------- 9 files changed, 32 insertions(+), 32 deletions(-) diff --git a/0_temp_helios_test/stage-1_i2v.sh b/0_temp_helios_test/stage-1_i2v.sh index d660198b1c9b..335cc832ae4a 100644 --- a/0_temp_helios_test/stage-1_i2v.sh +++ b/0_temp_helios_test/stage-1_i2v.sh @@ -2,7 +2,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --base_model_path "BestWishYsh/Helios-Base" \ --transformer_path "BestWishYsh/Helios-Base" \ --sample_type "i2v" \ - --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" \ + --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" \ --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. 
As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." \ --use_dynamic_shifting \ --output_folder "./output_helios/stage-1" diff --git a/0_temp_helios_test/stage-1_v2v.sh b/0_temp_helios_test/stage-1_v2v.sh index 20177851204d..56a00631a988 100644 --- a/0_temp_helios_test/stage-1_v2v.sh +++ b/0_temp_helios_test/stage-1_v2v.sh @@ -2,7 +2,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --base_model_path "BestWishYsh/Helios-Base" \ --transformer_path "BestWishYsh/Helios-Base" \ --sample_type "v2v" \ - --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" \ + --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ --use_dynamic_shifting \ --output_folder "./output_helios/stage-1" diff --git a/0_temp_helios_test/stage-2_i2v.sh b/0_temp_helios_test/stage-2_i2v.sh index 7ba8a9520ea7..79a6515b5a24 100644 --- a/0_temp_helios_test/stage-2_i2v.sh +++ b/0_temp_helios_test/stage-2_i2v.sh @@ -2,7 +2,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --base_model_path "BestWishYsh/Helios-Mid" \ --transformer_path "BestWishYsh/Helios-Mid" \ --sample_type "i2v" \ - --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" \ + --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" \ --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." 
\ --is_enable_stage2 \ --stage2_num_inference_steps_list 20 20 20 \ diff --git a/0_temp_helios_test/stage-2_v2v.sh b/0_temp_helios_test/stage-2_v2v.sh index 167edeca6cba..261fdeed09b6 100644 --- a/0_temp_helios_test/stage-2_v2v.sh +++ b/0_temp_helios_test/stage-2_v2v.sh @@ -2,7 +2,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --base_model_path "BestWishYsh/Helios-Mid" \ --transformer_path "BestWishYsh/Helios-Mid" \ --sample_type "v2v" \ - --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" \ + --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ --is_enable_stage2 \ --stage2_num_inference_steps_list 20 20 20 \ diff --git a/0_temp_helios_test/stage-3_i2v.sh b/0_temp_helios_test/stage-3_i2v.sh index 4735788b1176..de40fc798414 100644 --- a/0_temp_helios_test/stage-3_i2v.sh +++ b/0_temp_helios_test/stage-3_i2v.sh @@ -2,7 +2,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --base_model_path "BestWishYsh/Helios-Distilled" \ --transformer_path "BestWishYsh/Helios-Distilled" \ --sample_type "i2v" \ - --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" \ + --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" \ --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." 
\ --num_frames 240 \ --guidance_scale 1.0 \ diff --git a/0_temp_helios_test/stage-3_v2v.sh b/0_temp_helios_test/stage-3_v2v.sh index 088004a15a63..f4b001908e17 100644 --- a/0_temp_helios_test/stage-3_v2v.sh +++ b/0_temp_helios_test/stage-3_v2v.sh @@ -2,7 +2,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --base_model_path "BestWishYsh/Helios-Distilled" \ --transformer_path "BestWishYsh/Helios-Distilled" \ --sample_type "v2v" \ - --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" \ + --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." \ --num_frames 240 \ --guidance_scale 1.0 \ diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index e59c38211663..a8f2c14104db 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -202,7 +202,7 @@ apparent, revealing the restless expanse of the ocean stretching beyond. The sce relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might. 
""" -image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" +image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" output = pipeline( prompt=prompt, @@ -222,7 +222,7 @@ emphasizing its dynamic movement. The road curves gently, with a guardrail visib the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery. """ -video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" +video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" output = pipeline( prompt=prompt, @@ -298,7 +298,7 @@ apparent, revealing the restless expanse of the ocean stretching beyond. The sce relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might. """ -image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" +image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" output = pipeline( prompt=prompt, @@ -323,7 +323,7 @@ emphasizing its dynamic movement. The road curves gently, with a guardrail visib the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery. 
""" -video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" +video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" output = pipeline( prompt=prompt, @@ -404,7 +404,7 @@ apparent, revealing the restless expanse of the ocean stretching beyond. The sce relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might. """ -image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/wave.jpg" +image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" output = pipeline( prompt=prompt, @@ -429,7 +429,7 @@ emphasizing its dynamic movement. The road curves gently, with a guardrail visib the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery. 
""" -video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/refs%2Fpr%2F587/diffusers/helios/car.mp4" +video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" output = pipeline( prompt=prompt, diff --git a/docs/source/en/using-diffusers/helios.md b/docs/source/en/using-diffusers/helios.md index 6a65ba24141d..0f9ea997e248 100644 --- a/docs/source/en/using-diffusers/helios.md +++ b/docs/source/en/using-diffusers/helios.md @@ -57,7 +57,7 @@ pipe.to("cuda") @@ -66,7 +66,7 @@ pipe.to("cuda") @@ -81,22 +81,22 @@ pipe.to("cuda") Generated Video - + A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees. - + A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective. 
@@ -110,18 +110,18 @@ pipe.to("cuda") Generated Video - The prompt can be found here + The prompt can be found here - The prompt can be found here + The prompt can be found here diff --git a/docs/source/zh/using-diffusers/helios.md b/docs/source/zh/using-diffusers/helios.md index 9e1ac69bb855..5e18f02bfe85 100644 --- a/docs/source/zh/using-diffusers/helios.md +++ b/docs/source/zh/using-diffusers/helios.md @@ -57,7 +57,7 @@ pipe.to("cuda") @@ -66,7 +66,7 @@ pipe.to("cuda") @@ -81,22 +81,22 @@ pipe.to("cuda") Generated Video - + A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees. - + A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective. 
@@ -110,18 +110,18 @@ pipe.to("cuda") Generated Video - The prompt can be found here + The prompt can be found here - The prompt can be found here + The prompt can be found here From 75809c6153804f0491a159fff1d57b97cdb78156 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Wed, 25 Feb 2026 11:00:04 +0000 Subject: [PATCH 009/107] change default config --- .../models/transformers/transformer_helios.py | 6 +++--- src/diffusers/pipelines/helios/pipeline_helios.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index fb1e0f15cf87..94e2142801c1 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -671,9 +671,9 @@ def __init__( qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, added_kv_proj_dim: Optional[int] = None, - guidance_cross_attn: bool = False, - zero_history_timestep: bool = False, - has_multi_term_memory_patch: bool = False, + guidance_cross_attn: bool = True, + zero_history_timestep: bool = True, + has_multi_term_memory_patch: bool = True, is_amplify_history: bool = False, history_scale_mode: str = "per_head", # [scalar, per_head] ) -> None: diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index c59c809f4ef8..d4f770af9776 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -407,9 +407,9 @@ def prepare_latents( self, batch_size: int, num_channels_latents: int = 16, - height: int = 480, - width: int = 832, - num_frames: int = 81, + height: int = 384, + width: int = 640, + num_frames: int = 33, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -891,7 +891,7 @@ def __call__( 
negative_prompt: Union[str, List[str]] = None, height: int = 384, width: int = 640, - num_frames: int = 73, + num_frames: int = 132, num_inference_steps: int = 50, guidance_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, @@ -954,11 +954,11 @@ def __call__( negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale` < `1`). - height (`int`, defaults to `480`): + height (`int`, defaults to `384`): The height in pixels of the generated image. - width (`int`, defaults to `832`): + width (`int`, defaults to `640`): The width in pixels of the generated image. - num_frames (`int`, defaults to `81`): + num_frames (`int`, defaults to `132`): The number of frames in the generated video. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the From 608f22ce523a72f171c15611d28d5b9611d014cf Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Wed, 25 Feb 2026 13:03:47 +0000 Subject: [PATCH 010/107] small refactor --- .../models/transformers/transformer_helios.py | 11 +++++++---- src/diffusers/pipelines/helios/pipeline_helios.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 94e2142801c1..25b90716ad1b 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -658,7 +658,7 @@ class HeliosTransformer3DModel( @register_to_config def __init__( self, - patch_size: Tuple[int] = (1, 2, 2), + patch_size: tuple[int, ...] 
= (1, 2, 2), num_attention_heads: int = 40, attention_head_dim: int = 128, in_channels: int = 16, @@ -668,9 +668,12 @@ def __init__( ffn_dim: int = 13824, num_layers: int = 40, cross_attn_norm: bool = True, - qk_norm: Optional[str] = "rms_norm_across_heads", + qk_norm: str | None = "rms_norm_across_heads", eps: float = 1e-6, - added_kv_proj_dim: Optional[int] = None, + image_dim: int | None = None, + added_kv_proj_dim: int | None = None, + rope_dim: tuple[int, ...] = (44, 42, 42), + rope_theta: float = 10000.0, guidance_cross_attn: bool = True, zero_history_timestep: bool = True, has_multi_term_memory_patch: bool = True, @@ -683,7 +686,7 @@ def __init__( out_channels = out_channels or in_channels # 1. Patch & position embedding - self.rope = HeliosRotaryPosEmbed(rope_dim=(44, 42, 42), theta=10000.0) + self.rope = HeliosRotaryPosEmbed(rope_dim=rope_dim, theta=rope_theta) self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) # 2. Initial Multi Term Memory Patch diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index d4f770af9776..9aee51c27e58 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -223,7 +223,7 @@ def __init__( tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, vae: AutoencoderKLWan, - scheduler: Union[UniPCMultistepScheduler, HeliosUniPCScheduler], + scheduler: UniPCMultistepScheduler | HeliosUniPCScheduler, transformer: HeliosTransformer3DModel, ): super().__init__() From 31fb0cfdc56ca9ee95a9b36ddfbec75ab425a878 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Wed, 25 Feb 2026 13:04:00 +0000 Subject: [PATCH 011/107] add test --- tests/lora/test_lora_layers_helios.py | 120 +++++++++++ .../test_models_transformer_helios.py | 194 ++++++++++++++++++ 2 files changed, 314 insertions(+) create mode 100644 tests/lora/test_lora_layers_helios.py diff --git 
a/tests/lora/test_lora_layers_helios.py b/tests/lora/test_lora_layers_helios.py new file mode 100644 index 000000000000..fbcc3b808eee --- /dev/null +++ b/tests/lora/test_lora_layers_helios.py @@ -0,0 +1,120 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, HeliosPipeline, HeliosTransformer3DModel + +from ..testing_utils import floats_tensor, require_peft_backend, skip_mps + + +sys.path.append(".") + +from .utils import PeftLoraLoaderMixinTests # noqa: E402 + + +@require_peft_backend +@skip_mps +class HeliosLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = HeliosPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_kwargs = {} + + transformer_kwargs = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, + "in_channels": 16, + "out_channels": 16, + "text_dim": 32, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "rope_dim": (4, 4, 4), + "has_multi_term_memory_patch": True, + "guidance_cross_attn": True, + "zero_history_timestep": True, + "is_amplify_history": False, + } + transformer_cls = HeliosTransformer3DModel + vae_kwargs = { + "base_dim": 3, + "z_dim": 16, + "dim_mult": [1, 1, 1, 1], + "num_res_blocks": 1, + 
"temperal_downsample": [False, True, True], + } + vae_cls = AutoencoderKLWan + has_two_text_encoders = True + tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + + text_encoder_target_modules = ["q", "k", "v", "o"] + + supports_text_encoder_loras = False + + @property + def output_shape(self): + return (1, 33, 32, 32, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 16 + num_channels = 4 + num_frames = 9 + num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1 + sizes = (4, 4) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "", + "num_frames": num_frames, + "num_inference_steps": 1, + "guidance_scale": 6.0, + "height": 32, + "width": 32, + "max_sequence_length": sequence_length, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + + def test_simple_inference_with_text_denoiser_lora_unfused(self): + super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + + @unittest.skip("Not supported in Helios.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in Helios.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in Helios.") + def test_modify_padding_mode(self): + pass diff --git a/tests/models/transformers/test_models_transformer_helios.py 
b/tests/models/transformers/test_models_transformer_helios.py index e69de29bb2d1..73fe7a99b86d 100644 --- a/tests/models/transformers/test_models_transformer_helios.py +++ b/tests/models/transformers/test_models_transformer_helios.py @@ -0,0 +1,194 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from diffusers import HeliosTransformer3DModel + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + BitsAndBytesTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) + + +enable_full_determinism() + + +class HeliosTransformer3DTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return HeliosTransformer3DModel + + @property + def pretrained_model_name_or_path(self): + return "hf-internal-testing/tiny-helios-base-transformer" + + @property + def output_shape(self) -> tuple[int, ...]: + return (4, 2, 16, 16) + + @property + def input_shape(self) -> tuple[int, ...]: + return (4, 2, 16, 16) + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]: + return { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, 
+ "in_channels": 4, + "out_channels": 4, + "text_dim": 16, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "rope_dim": (4, 4, 4), + "has_multi_term_memory_patch": True, + "guidance_cross_attn": True, + "zero_history_timestep": True, + "is_amplify_history": False, + } + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + batch_size = 1 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + indices_hidden_states = torch.ones((2,)).to(torch_device) + indices_latents_history_short = torch.ones((num_frames - 1,)).to(torch_device) + indices_latents_history_mid = torch.ones((num_frames - 1,)).to(torch_device) + indices_latents_history_long = torch.ones(((num_frames - 1) * 4,)).to(torch_device) + latents_history_short = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device) + latents_history_mid = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device) + latents_history_long = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to( + torch_device + ) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "indices_hidden_states": indices_hidden_states, + "indices_latents_history_short": indices_latents_history_short, + "indices_latents_history_mid": indices_latents_history_mid, + "indices_latents_history_long": indices_latents_history_long, + "latents_history_short": latents_history_short, + "latents_history_mid": latents_history_mid, + "latents_history_long": latents_history_long, + } + 
+ +class TestHeliosTransformer3D(HeliosTransformer3DTesterConfig, ModelTesterMixin): + """Core model tests for Helios Transformer 3D.""" + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + # Skip: fp16/bf16 require very high atol to pass, providing little signal. + # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. + pytest.skip("Tolerance requirements too high for meaningful test") + + +class TestHeliosTransformer3DMemory(HeliosTransformer3DTesterConfig, MemoryTesterMixin): + """Memory optimization tests for Helios Transformer 3D.""" + + +class TestHeliosTransformer3DTraining(HeliosTransformer3DTesterConfig, TrainingTesterMixin): + """Training tests for Helios Transformer 3D.""" + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HeliosTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestHeliosTransformer3DAttention(HeliosTransformer3DTesterConfig, AttentionTesterMixin): + """Attention processor tests for Helios Transformer 3D.""" + + +class TestHeliosTransformer3DCompile(HeliosTransformer3DTesterConfig, TorchCompileTesterMixin): + """Torch compile tests for Helios Transformer 3D.""" + + +class TestHeliosTransformer3DBitsAndBytes(HeliosTransformer3DTesterConfig, BitsAndBytesTesterMixin): + """BitsAndBytes quantization tests for Helios Transformer 3D.""" + + @property + def torch_dtype(self): + return torch.float16 + + +class TestHeliosTransformer3DTorchAo(HeliosTransformer3DTesterConfig, TorchAoTesterMixin): + """TorchAO quantization tests for Helios Transformer 3D.""" + + @property + def torch_dtype(self): + return torch.bfloat16 + + +# class TestHeliosTransformer3DGGUF(HeliosTransformer3DTesterConfig, GGUFTesterMixin): +# """GGUF quantization tests for Helios Transformer 3D.""" + +# @property +# def 
gguf_filename(self): +# return "" + +# @property +# def torch_dtype(self): +# return torch.bfloat16 + +# def _create_quantized_model(self, config_kwargs=None, **extra_kwargs): +# return super()._create_quantized_model( +# config_kwargs, config="BestWishYsh/Helios-Base", subfolder="transformer", **extra_kwargs +# ) + + +# class TestHeliosTransformer3DGGUFCompile(HeliosTransformer3DTesterConfig, GGUFCompileTesterMixin): +# """GGUF + compile tests for Helios Transformer 3D.""" + +# @property +# def gguf_filename(self): +# return "" + +# @property +# def torch_dtype(self): +# return torch.bfloat16 + +# def _create_quantized_model(self, config_kwargs=None, **extra_kwargs): +# return super()._create_quantized_model( +# config_kwargs, config="BestWishYsh/Helios-Base", subfolder="transformer", **extra_kwargs +# ) From f276d2ff84c3ac60ec0349318049fd032abe3687 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Thu, 26 Feb 2026 09:18:36 +0800 Subject: [PATCH 012/107] Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_helios.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 25b90716ad1b..4a5a54de6149 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -142,14 +142,6 @@ def __call__( return hidden_states -class HeliosAttnProcessor2_0: - def __new__(cls, *args, **kwargs): - deprecation_message = ( - "The HeliosAttnProcessor2_0 class is deprecated and will be removed in a future version. " - "Please use HeliosAttnProcessor instead. 
" - ) - deprecate("HeliosAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) - return HeliosAttnProcessor(*args, **kwargs) class HeliosAttention(torch.nn.Module, AttentionModuleMixin): From 061a4429336175493c74f8c9fd4b97a30e590294 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Thu, 26 Feb 2026 01:34:13 +0000 Subject: [PATCH 013/107] remove register_buffer for _scale_cache --- .../models/transformers/transformer_helios.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 4a5a54de6149..5487aa03b2bc 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import apply_lora_scale, deprecate, logging +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward @@ -118,7 +118,7 @@ def __call__( history_seq_len = hidden_states.shape[1] - original_context_length if history_seq_len > 0: - scale_key = attn.get_scale_key() + scale_key = 1.0 + torch.sigmoid(attn.history_key_scale) * (attn.max_scale - 1.0) if attn.history_scale_mode == "per_head": scale_key = scale_key.view(1, 1, -1, 1) key = torch.cat([key[:, :history_seq_len] * scale_key, key[:, history_seq_len:]], dim=1) @@ -142,8 +142,6 @@ def __call__( return hidden_states - - class HeliosAttention(torch.nn.Module, AttentionModuleMixin): _default_processor_cls = HeliosAttnProcessor _available_processors = [HeliosAttnProcessor] @@ -205,16 +203,6 @@ def __init__( raise ValueError(f"Unknown history_scale_mode: {history_scale_mode}") 
self.history_scale_mode = history_scale_mode self.max_scale = 10.0 - self.register_buffer("_scale_cache", None) - - def get_scale_key(self): - if self.history_key_scale.requires_grad: - scale = 1.0 + torch.sigmoid(self.history_key_scale) * (self.max_scale - 1.0) - else: - if self._scale_cache is None: - self._scale_cache = 1.0 + torch.sigmoid(self.history_key_scale) * (self.max_scale - 1.0) - scale = self._scale_cache - return scale def fuse_projections(self): if getattr(self, "fused_projections", False): From ba05daf6f2d5f102102a5266ce00c6b91e16fecd Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Thu, 26 Feb 2026 01:36:45 +0000 Subject: [PATCH 014/107] fix non-cuda devices error --- src/diffusers/pipelines/helios/pipeline_helios.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 9aee51c27e58..ec01408051bf 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -78,17 +78,13 @@ """ -@torch.amp.autocast("cuda", dtype=torch.float32) def optimized_scale(positive_flat, negative_flat): # Calculate dot production dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) - # Squared norm of uncondition squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 - # st_star = v_cond^T * v_uncond / ||v_uncond||^2 st_star = dot_product / squared_norm - return st_star From bc3c5282002a68b4070b8ae5dcac6e480f3464cf Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 01:46:24 +0000 Subject: [PATCH 015/107] remove "handle the case when timestep is 2D" --- .../models/transformers/transformer_helios.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 5487aa03b2bc..21e7a3dbceca 100644 --- 
a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -295,26 +295,16 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, is_return_encoder_hidden_states: bool = True, ): - B = None - F = None - if timestep.ndim == 2: - B, F = timestep.shape - timestep = timestep.flatten() - - timestep = self.timesteps_proj(timestep) # torch.Size([2]) -> torch.Size([2, 256]) + timestep = self.timesteps_proj(timestep) time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: timestep = timestep.to(time_embedder_dtype) - temb = self.time_embedder(timestep).type_as(encoder_hidden_states) # torch.Size([2, 1536]) - timestep_proj = self.time_proj(self.act_fn(temb)) # torch.Size([2, 9216] - - if B is not None and F is not None: - temb = temb.reshape(B, F, -1) - timestep_proj = timestep_proj.reshape(B, F, -1) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) if encoder_hidden_states is not None and is_return_encoder_hidden_states: - encoder_hidden_states = self.text_embedder(encoder_hidden_states) # torch.Size([2, 512, 1536]) + encoder_hidden_states = self.text_embedder(encoder_hidden_states) return temb, timestep_proj, encoder_hidden_states From 250805858786c05c2c19a6961a03c264ca31ddf6 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 02:52:24 +0000 Subject: [PATCH 016/107] refactor HeliosMultiTermMemoryPatch and process_input_hidden_states --- 0_temp_helios_test/requirements.txt | 2 +- .../models/transformers/transformer_helios.py | 247 +++++++----------- 2 files changed, 89 insertions(+), 160 deletions(-) diff --git a/0_temp_helios_test/requirements.txt b/0_temp_helios_test/requirements.txt index b1f777548ab3..d3c12e60923a 100644 --- a/0_temp_helios_test/requirements.txt +++ b/0_temp_helios_test/requirements.txt @@ -5,7 +5,7 
@@ triton==3.3.1 # diffusers==0.36.0 # transformers==4.57.6 # sentence-transformers==5.2.3 -git+https://github.com/SHYuanBest/diffusers.git@test +# git+https://github.com/SHYuanBest/diffusers.git@test git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/sentence-transformers.git accelerate==1.12.0 diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 21e7a3dbceca..7aca91c75e94 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -309,86 +309,6 @@ def forward( return temb, timestep_proj, encoder_hidden_states -class HeliosMultiTermMemoryPatch(nn.Module): - def __init__(self, in_channels, inner_dim): - super().__init__() - self.patch_short = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) - self.patch_mid = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) - self.patch_long = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) - - def forward( - self, - hidden_states, - rope_freqs, - rope_fn, - indices_latents_history_short=None, - indices_latents_history_mid=None, - indices_latents_history_long=None, - latents_history_short=None, - latents_history_mid=None, - latents_history_long=None, - ): - # Process clean latents (1x) - if latents_history_short is not None and indices_latents_history_short is not None: - latents_history_short = latents_history_short.to(hidden_states) - latents_history_short = self.patch_short(latents_history_short) - _, _, _, H1, W1 = latents_history_short.shape - latents_history_short = latents_history_short.flatten(2).transpose(1, 2) - - rope_freqs_history_short = rope_fn( - frame_indices=indices_latents_history_short, - height=H1, - width=W1, - device=latents_history_short.device, - ) - rope_freqs_history_short = rope_freqs_history_short.flatten(2).transpose(1, 2) - - 
hidden_states = torch.cat([latents_history_short, hidden_states], dim=1) - rope_freqs = torch.cat([rope_freqs_history_short, rope_freqs], dim=1) - - # Process 2x history latents - if latents_history_mid is not None and indices_latents_history_mid is not None: - latents_history_mid = latents_history_mid.to(hidden_states) - latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4)) - latents_history_mid = self.patch_mid(latents_history_mid) - latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2) - - rope_freqs_history_mid = rope_fn( - frame_indices=indices_latents_history_mid, - height=H1, - width=W1, - device=latents_history_mid.device, - ) - rope_freqs_history_mid = pad_for_3d_conv(rope_freqs_history_mid, (2, 2, 2)) - rope_freqs_history_mid = center_down_sample_3d(rope_freqs_history_mid, (2, 2, 2)) - rope_freqs_history_mid = rope_freqs_history_mid.flatten(2).transpose(1, 2) - - hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1) - rope_freqs = torch.cat([rope_freqs_history_mid, rope_freqs], dim=1) - - # Process 4x history latents - if latents_history_long is not None and indices_latents_history_long is not None: - latents_history_long = latents_history_long.to(hidden_states) - latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8)) - latents_history_long = self.patch_long(latents_history_long) - latents_history_long = latents_history_long.flatten(2).transpose(1, 2) - - rope_freqs_history_long = rope_fn( - frame_indices=indices_latents_history_long, - height=H1, - width=W1, - device=latents_history_long.device, - ) - rope_freqs_history_long = pad_for_3d_conv(rope_freqs_history_long, (4, 4, 4)) - rope_freqs_history_long = center_down_sample_3d(rope_freqs_history_long, (4, 4, 4)) - rope_freqs_history_long = rope_freqs_history_long.flatten(2).transpose(1, 2) - - hidden_states = torch.cat([latents_history_long, hidden_states], dim=1) - rope_freqs = torch.cat([rope_freqs_history_long, rope_freqs], dim=1) - - 
return hidden_states, rope_freqs - - class HeliosRotaryPosEmbed(nn.Module): def __init__(self, rope_dim, theta): super().__init__() @@ -663,7 +583,9 @@ def __init__( self.zero_history_timestep = zero_history_timestep self.inner_dim = inner_dim if has_multi_term_memory_patch: - self.multi_term_memory_patch = HeliosMultiTermMemoryPatch(in_channels, self.inner_dim) + self.patch_short = nn.Conv3d(in_channels, self.inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.patch_mid = nn.Conv3d(in_channels, self.inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.patch_long = nn.Conv3d(in_channels, self.inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) # 3. Condition embeddings self.condition_embedder = HeliosTimeTextEmbedding( @@ -699,67 +621,6 @@ def __init__( self.gradient_checkpointing = False - def process_input_hidden_states( - self, - latents, - indices_hidden_states=None, - indices_latents_history_short=None, - indices_latents_history_mid=None, - indices_latents_history_long=None, - latents_history_short=None, - latents_history_mid=None, - latents_history_long=None, - ): - height_list = [] - width_list = [] - temporal_list = [] - seq_list = [] - - hidden_states = self.patch_embedding(latents) - B, C, T, H, W = hidden_states.shape - - if indices_hidden_states is None: - indices_hidden_states = torch.arange(0, T).unsqueeze(0).expand(B, -1) - - hidden_states = hidden_states.flatten(2).transpose( - 1, 2 - ) # torch.Size([1, 3072, 9, 44, 34]) -> torch.Size([1, 13464, 3072]) - - rope_freqs = self.rope( - frame_indices=indices_hidden_states, - height=H, - width=W, - device=hidden_states.device, - ) # torch.Size([1, 9]) -> torch.Size([1, 256, 9, 44, 34]) - rope_freqs = rope_freqs.flatten(2).transpose(1, 2) # torch.Size([1, 13464, 256]) - - height_list.append(H) - width_list.append(W) - temporal_list.append(T) - seq_list.append(hidden_states.shape[1]) - - if latents_history_short is not None: - hidden_states, rope_freqs = self.multi_term_memory_patch( - 
hidden_states=hidden_states, - rope_freqs=rope_freqs, - rope_fn=self.rope, - latents_history_short=latents_history_short, - indices_latents_history_short=indices_latents_history_short, - latents_history_mid=latents_history_mid, - indices_latents_history_mid=indices_latents_history_mid, - latents_history_long=latents_history_long, - indices_latents_history_long=indices_latents_history_long, - ) - - return ( - hidden_states, - rope_freqs, - height_list, - width_list, - temporal_list, - seq_list, - ) - @apply_lora_scale("attention_kwargs") def forward( self, @@ -807,23 +668,91 @@ def forward( batch_size = hidden_states.shape[0] p_t, p_h, p_w = self.config.patch_size - ( - hidden_states, - rotary_emb, - post_patch_height_list, - post_patch_width_list, - post_patch_num_frames_list, - original_context_length_list, - ) = self.process_input_hidden_states( - latents=hidden_states, - indices_hidden_states=indices_hidden_states, - indices_latents_history_short=indices_latents_history_short, - indices_latents_history_mid=indices_latents_history_mid, - indices_latents_history_long=indices_latents_history_long, - latents_history_short=latents_history_short, - latents_history_mid=latents_history_mid, - latents_history_long=latents_history_long, - ) # hidden: [high, mid, low] -> [low, mid, high] + # hidden: [high, mid, low] -> [low, mid, high] + post_patch_height_list = [] + post_patch_width_list = [] + post_patch_num_frames_list = [] + original_context_length_list = [] + + # Process noisy latents + hidden_states = self.patch_embedding(hidden_states) + B, C, T, H, W = hidden_states.shape + + if indices_hidden_states is None: + indices_hidden_states = torch.arange(0, T).unsqueeze(0).expand(B, -1) + + hidden_states = hidden_states.flatten(2).transpose(1, 2) + rotary_emb = self.rope( + frame_indices=indices_hidden_states, + height=H, + width=W, + device=hidden_states.device, + ) + rotary_emb = rotary_emb.flatten(2).transpose(1, 2) + + post_patch_height_list.append(H) + 
post_patch_width_list.append(W) + post_patch_num_frames_list.append(T) + original_context_length_list.append(hidden_states.shape[1]) + + # Process short history latents + if latents_history_short is not None and indices_latents_history_short is not None: + latents_history_short = latents_history_short.to(hidden_states) + latents_history_short = self.patch_short(latents_history_short) + _, _, _, H1, W1 = latents_history_short.shape + latents_history_short = latents_history_short.flatten(2).transpose(1, 2) + + rotary_emb_history_short = self.rope( + frame_indices=indices_latents_history_short, + height=H1, + width=W1, + device=latents_history_short.device, + ) + rotary_emb_history_short = rotary_emb_history_short.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([latents_history_short, hidden_states], dim=1) + rotary_emb = torch.cat([rotary_emb_history_short, rotary_emb], dim=1) + + # Process mid history latents + if latents_history_mid is not None and indices_latents_history_mid is not None: + latents_history_mid = latents_history_mid.to(hidden_states) + latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4)) + latents_history_mid = self.patch_mid(latents_history_mid) + latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2) + + rotary_emb_history_mid = self.rope( + frame_indices=indices_latents_history_mid, + height=H1, + width=W1, + device=latents_history_mid.device, + ) + rotary_emb_history_mid = pad_for_3d_conv(rotary_emb_history_mid, (2, 2, 2)) + rotary_emb_history_mid = center_down_sample_3d(rotary_emb_history_mid, (2, 2, 2)) + rotary_emb_history_mid = rotary_emb_history_mid.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1) + rotary_emb = torch.cat([rotary_emb_history_mid, rotary_emb], dim=1) + + # Process long history latents + if latents_history_long is not None and indices_latents_history_long is not None: + latents_history_long = latents_history_long.to(hidden_states) + 
latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8)) + latents_history_long = self.patch_long(latents_history_long) + latents_history_long = latents_history_long.flatten(2).transpose(1, 2) + + rotary_emb_history_long = self.rope( + frame_indices=indices_latents_history_long, + height=H1, + width=W1, + device=latents_history_long.device, + ) + rotary_emb_history_long = pad_for_3d_conv(rotary_emb_history_long, (4, 4, 4)) + rotary_emb_history_long = center_down_sample_3d(rotary_emb_history_long, (4, 4, 4)) + rotary_emb_history_long = rotary_emb_history_long.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([latents_history_long, hidden_states], dim=1) + rotary_emb = torch.cat([rotary_emb_history_long, rotary_emb], dim=1) + post_patch_num_frames = sum(post_patch_num_frames_list) post_patch_height = sum(post_patch_height_list) post_patch_width = sum(post_patch_width_list) From 96747b7a5b932957a3249435c5b0b6effce515af Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Thu, 26 Feb 2026 10:54:15 +0800 Subject: [PATCH 017/107] Update src/diffusers/pipelines/helios/pipeline_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/pipelines/helios/pipeline_helios.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index ec01408051bf..94babd3df285 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -105,13 +105,13 @@ def prompt_clean(text): return text +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.15, - exp_max=7.0, ): m = (max_shift - base_shift) / (max_seq_len - base_seq_len) b = base_shift - m * base_seq_len From 
37eb8f06d4b73d55d43b0484969a93aa6466fff2 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Thu, 26 Feb 2026 10:55:00 +0800 Subject: [PATCH 018/107] Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/models/transformers/transformer_helios.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 7aca91c75e94..e7ebe7c1fff3 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -560,7 +560,6 @@ def __init__( cross_attn_norm: bool = True, qk_norm: str | None = "rms_norm_across_heads", eps: float = 1e-6, - image_dim: int | None = None, added_kv_proj_dim: int | None = None, rope_dim: tuple[int, ...] = (44, 42, 42), rope_theta: float = 10000.0, From 439b76dba4ea63c17c51a4accb00988e75f651df Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Thu, 26 Feb 2026 11:28:51 +0800 Subject: [PATCH 019/107] Update src/diffusers/pipelines/helios/pipeline_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/pipelines/helios/pipeline_helios.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 94babd3df285..84521d537d9e 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -278,8 +278,8 @@ def _get_t5_prompt_embeds( def encode_prompt( self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, 
prompt_embeds: Optional[torch.Tensor] = None, From bb2482a213c552974dc52963062bcb42fe05abf3 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 05:29:07 +0000 Subject: [PATCH 020/107] fix calculate_shift --- src/diffusers/pipelines/helios/pipeline_helios.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 84521d537d9e..ceac7db0eb15 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -141,7 +141,6 @@ def apply_schedule_shift( max_seq_len if max_seq_len is not None else 4096, base_shift if base_shift is not None else 0.5, max_shift if max_shift is not None else 1.15, - exp_max if exp_max is not None else 7.0, ) if is_exponential: mu = min(mu, math.log(exp_max)) From 36809980366d4094c7a5dc617233241e1b2e1606 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Thu, 26 Feb 2026 14:34:13 +0800 Subject: [PATCH 021/107] Update src/diffusers/pipelines/helios/pipeline_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/pipelines/helios/pipeline_helios.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index ceac7db0eb15..723d322dacc0 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -134,13 +134,12 @@ def apply_schedule_shift( ): if mu is None: # Resolution-dependent shifting of timestep schedules as per section 5.3.2 of SD3 paper - image_seq_len = (noise.shape[-1] * noise.shape[-2] * noise.shape[-3]) // 4 # patch size 1,2,2 mu = calculate_shift( image_seq_len, - base_seq_len if base_seq_len is not None else 256, - max_seq_len if max_seq_len is not None else 4096, - base_shift if base_shift is not None 
else 0.5, - max_shift if max_shift is not None else 1.15, + base_seq_len, + max_seq_len, + base_shift, + max_shift, ) if is_exponential: mu = min(mu, math.log(exp_max)) From 9b4fbe785b1ed8959d71ff1964db8ecc2b7df680 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 06:54:10 +0000 Subject: [PATCH 022/107] rewritten `einops` in pure `torch` --- .../pipelines/helios/pipeline_helios.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 723d322dacc0..33f7c3658a0c 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -21,7 +21,6 @@ import regex as re import torch import torch.nn.functional as F -from einops import rearrange from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -134,6 +133,7 @@ def apply_schedule_shift( ): if mu is None: # Resolution-dependent shifting of timestep schedules as per section 5.3.2 of SD3 paper + image_seq_len = (noise.shape[-1] * noise.shape[-2] * noise.shape[-3]) // 4 # patch size 1,2,2 mu = calculate_shift( image_seq_len, base_seq_len, @@ -682,12 +682,8 @@ def stage2_sample( callback_on_step_end_tensor_inputs: list = None, progress_bar=None, ): - num_frmaes, height, width = ( - latents.shape[-3], - latents.shape[-2], - latents.shape[-1], - ) - latents = rearrange(latents, "b c t h w -> (b t) c h w") + batch_size, num_channel, num_frmaes, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frmaes, num_channel, height, width) for _ in range(stage2_num_stages - 1): height //= 2 width //= 2 @@ -699,7 +695,7 @@ def stage2_sample( ) * 2 ) - latents = rearrange(latents, "(b t) c h w -> b c t h w", t=num_frmaes) + latents = latents.reshape(batch_size, num_frmaes, num_channel, height, width).permute(0, 2, 1, 3, 
4) batch_size = latents.shape[0] if use_dmd: @@ -721,9 +717,11 @@ def stage2_sample( height *= 2 width *= 2 num_frames = latents.shape[2] - latents = rearrange(latents, "b c t h w -> (b t) c h w") + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frmaes, num_channel, height // 2, width // 2 + ) latents = F.interpolate(latents, size=(height, width), mode="nearest") - latents = rearrange(latents, "(b t) c h w -> b c t h w", t=num_frames) + latents = latents.reshape(batch_size, num_frmaes, num_channel, height, width).permute(0, 2, 1, 3, 4) # Fix the stage ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s] # the original coeff of signal gamma = self.scheduler.config.gamma From cee81f8111856aa2ca748fb3ba2d0eaeef73b318 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 07:08:01 +0000 Subject: [PATCH 023/107] fix: pass patch_size to apply_schedule_shift instead of hardcoding --- .../pipelines/helios/pipeline_helios.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 33f7c3658a0c..a16ffbea431b 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -119,8 +119,8 @@ def calculate_shift( def apply_schedule_shift( + image_seq_len, sigmas, - noise, sigmas_two=None, base_seq_len: int = 256, max_seq_len: int = 4096, @@ -133,7 +133,6 @@ def apply_schedule_shift( ): if mu is None: # Resolution-dependent shifting of timestep schedules as per section 5.3.2 of SD3 paper - image_seq_len = (noise.shape[-1] * noise.shape[-2] * noise.shape[-3]) // 4 # patch size 1,2,2 mu = calculate_shift( image_seq_len, base_seq_len, @@ -737,9 +736,13 @@ def stage2_sample( start_point_list.append(latents) if use_dynamic_shifting: + patch_size = self.transformer.config.patch_size + image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( 
+ patch_size[0] * patch_size[1] * patch_size[2] + ) temp_sigmas = apply_schedule_shift( + image_seq_len, self.scheduler.sigmas, - latents, base_seq_len=self.scheduler.config.get("base_image_seq_len", 256), max_seq_len=self.scheduler.config.get("max_image_seq_len", 4096), base_shift=self.scheduler.config.get("base_shift", 0.5), @@ -1319,12 +1322,16 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device) if use_dynamic_shifting: + patch_size = self.transformer.config.patch_size + image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( + patch_size[0] * patch_size[1] * patch_size[2] + ) sigmas = torch.linspace( 0.999, 0.0, steps=num_inference_steps + 1, dtype=torch.float32, device=device )[:-1] sigmas = apply_schedule_shift( - sigmas=sigmas, - noise=latents, + image_seq_len, + sigmas, base_seq_len=self.scheduler.config.get("base_image_seq_len", 256), max_seq_len=self.scheduler.config.get("max_image_seq_len", 4096), base_shift=self.scheduler.config.get("base_shift", 0.5), From a7960fea8bd2e6f097547d1417dbec6f81651b19 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 07:17:55 +0000 Subject: [PATCH 024/107] remove the logics of 'vae_decode_type' --- 0_temp_helios_test/infer_helios.py | 9 --- .../pipelines/helios/pipeline_helios.py | 68 ++----------------- 2 files changed, 6 insertions(+), 71 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index dbd7dbf1ccdf..7e1d17b25950 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -66,7 +66,6 @@ def parse_args(): parser.add_argument("--num_inference_steps", type=int, default=50) parser.add_argument("--guidance_scale", type=float, default=5.0) parser.add_argument("--use_dynamic_shifting", action="store_true") - parser.add_argument("--vae_decode_type", type=str, default="default", choices=["default", "once", "default_fast"]) # cfg zero parser.add_argument("--use_cfg_zero_star", 
action="store_true") parser.add_argument("--use_zero_init", action="store_true") @@ -234,9 +233,6 @@ def main(): pipe.load_lora_weights(args.lora_path, adapter_name="default") pipe.set_adapters(["default"], adapter_weights=[1.0]) - if args.vae_decode_type == "once": - pipe.vae.enable_tiling() - if args.enable_compile: torch.backends.cudnn.benchmark = True pipe.text_encoder.compile(mode="max-autotune-no-cudagraphs", dynamic=False) @@ -295,7 +291,6 @@ def parse_list_input(input_string): num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, generator=torch.Generator(device="cuda").manual_seed(args.seed), - vae_decode_type=args.vae_decode_type, # stage 1 history_sizes=[16, 2, 1], latent_window_size=args.latent_window_size, @@ -358,7 +353,6 @@ def parse_list_input(input_string): num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, generator=torch.Generator(device="cuda").manual_seed(args.seed), - vae_decode_type=args.vae_decode_type, # stage 1 history_sizes=[16, 2, 1], latent_window_size=args.latent_window_size, @@ -429,7 +423,6 @@ def parse_list_input(input_string): num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, generator=torch.Generator(device="cuda").manual_seed(args.seed), - vae_decode_type=args.vae_decode_type, # stage 1 history_sizes=[16, 2, 1], latent_window_size=args.latent_window_size, @@ -489,7 +482,6 @@ def parse_list_input(input_string): num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, generator=torch.Generator(device="cuda").manual_seed(args.seed), - vae_decode_type=args.vae_decode_type, # stage 1 history_sizes=[16, 2, 1], latent_window_size=args.latent_window_size, @@ -534,7 +526,6 @@ def parse_list_input(input_string): num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, generator=torch.Generator(device="cuda").manual_seed(args.seed), - vae_decode_type=args.vae_decode_type, # stage 1 history_sizes=[16, 
2, 1], latent_window_size=args.latent_window_size, diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index a16ffbea431b..2d8ed2a29a19 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -14,7 +14,6 @@ import html import math -from enum import Enum from itertools import accumulate from typing import Any, Callable, Dict, List, Optional, Union @@ -180,11 +179,6 @@ def convert_flow_pred_to_x0(flow_pred, xt, timestep, sigmas, timesteps): return x0_pred.to(original_dtype) -class VAEDecodeType(str, Enum): - DEFAULT = "default" - DEFAULT_BATCH = "default_batch" - - class HeliosPipeline(DiffusionPipeline, HeliosLoraLoaderMixin): r""" Pipeline for text-to-video / image-to-video / video-to-video generation using Helios. @@ -937,8 +931,6 @@ def __call__( use_dmd: bool = False, is_skip_first_section: bool = False, is_amplify_first_chunk: bool = False, - # ------------ other ------------ - vae_decode_type: VAEDecodeType = "default", # "default", "default_batch" ): r""" The call function to the pipeline for generation. 
@@ -1441,65 +1433,17 @@ def __call__( slice(-latent_window_size, None), ) - if vae_decode_type == "default": - current_latents = real_history_latents[index_slice].to(vae_dtype) / latents_std + latents_mean - current_video = self.vae.decode(current_latents, return_dict=False)[0] + current_latents = real_history_latents[index_slice].to(vae_dtype) / latents_std + latents_mean + current_video = self.vae.decode(current_latents, return_dict=False)[0] - if history_video is None: - history_video = current_video - else: - history_video = torch.cat([history_video, current_video], dim=2) + if history_video is None: + history_video = current_video + else: + history_video = torch.cat([history_video, current_video], dim=2) self._current_timestep = None if output_type != "latent": - if vae_decode_type == "default_batch": - total_latent_frames = real_history_latents.shape[2] - batch_size = real_history_latents.shape[0] - num_chunks = total_latent_frames // latent_window_size - - chunks = ( - real_history_latents.reshape( - batch_size, - -1, - num_chunks, - latent_window_size, - real_history_latents.shape[-2], - real_history_latents.shape[-1], - ) - .permute(0, 2, 1, 3, 4, 5) - .reshape( - batch_size * num_chunks, - -1, - latent_window_size, - real_history_latents.shape[-2], - real_history_latents.shape[-1], - ) - ) - - chunks = chunks.to(vae_dtype) / latents_std + latents_mean - batch_video = self.vae.decode(chunks, return_dict=False)[0] - - video_frames_per_chunk = batch_video.shape[2] - history_video = ( - batch_video.reshape( - batch_size, - num_chunks, - -1, - video_frames_per_chunk, - batch_video.shape[-2], - batch_video.shape[-1], - ) - .permute(0, 2, 1, 3, 4, 5) - .reshape( - batch_size, - -1, - num_chunks * video_frames_per_chunk, - batch_video.shape[-2], - batch_video.shape[-1], - ) - ) - generated_frames = history_video.size(2) generated_frames = ( generated_frames - 1 From 2683ca03e883c82c526f5c3123fc3cb2b191c862 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 
Feb 2026 07:28:19 +0000 Subject: [PATCH 025/107] move some validation into check_inputs() --- .../pipelines/helios/pipeline_helios.py | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 2d8ed2a29a19..c564e0c6b6e4 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -358,6 +358,10 @@ def check_inputs( prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, + use_interpolate_prompt=False, + num_videos_per_prompt=None, + interpolate_time_list=None, + interpolation_steps=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -390,6 +394,16 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if use_interpolate_prompt: + assert num_videos_per_prompt == 1, f"num_videos_per_prompt must be 1, got {num_videos_per_prompt}" + assert isinstance(prompt, list), "prompt must be a list" + assert len(prompt) == len(interpolate_time_list), ( + f"Length mismatch: {len(prompt)} vs {len(interpolate_time_list)}" + ) + assert min(interpolate_time_list) > interpolation_steps, ( + f"Minimum value {min(interpolate_time_list)} must be greater than {interpolation_steps}" + ) + def prepare_latents( self, batch_size: int, @@ -490,7 +504,7 @@ def interpolate_prompt_embeds( self, prompt_embeds_1: torch.Tensor, prompt_embeds_2: torch.Tensor, - interpolation_steps: int = 4, + interpolation_steps: int = 3, ): x = torch.lerp( prompt_embeds_1, @@ -1001,19 +1015,6 @@ def __call__( if image is not None and video is not None: raise ValueError("image and video cannot be provided simultaneously") - if use_interpolate_prompt: - assert num_videos_per_prompt == 1, f"num_videos_per_prompt must be 
1, got {num_videos_per_prompt}" - assert isinstance(prompt, list), "prompt must be a list" - assert len(prompt) == len(interpolate_time_list), ( - f"Length mismatch: {len(prompt)} vs {len(interpolate_time_list)}" - ) - assert min(interpolate_time_list) > interpolation_steps, ( - f"Minimum value {min(interpolate_time_list)} must be greater than {interpolation_steps}" - ) - interpolate_interval_idx = None - interpolate_embeds = None - interpolate_cumulative_list = list(accumulate(interpolate_time_list)) - history_sizes = sorted(history_sizes, reverse=True) # From big to small latents_mean = ( @@ -1037,6 +1038,10 @@ def __call__( prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, + use_interpolate_prompt, + num_videos_per_prompt, + interpolate_time_list, + interpolation_steps, ) num_frames = max(num_frames, 1) @@ -1058,6 +1063,11 @@ def __call__( batch_size = prompt_embeds.shape[0] # 3. Encode input prompt + if use_interpolate_prompt: + interpolate_interval_idx = None + interpolate_embeds = None + interpolate_cumulative_list = list(accumulate(interpolate_time_list)) + all_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( self.encode_prompt( prompt=prompt, From 2f72a8ca6f11b4981e1529c93f79586f1dac1dda Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 08:12:16 +0000 Subject: [PATCH 026/107] rename helios scheduler & merge all into one step() --- 0_temp_helios_test/infer_helios.py | 4 +- docs/source/en/_toctree.yml | 4 +- .../schedulers/{helios_unipc.md => helios.md} | 10 +-- src/diffusers/__init__.py | 4 +- .../pipelines/helios/pipeline_helios.py | 9 +-- src/diffusers/schedulers/__init__.py | 4 +- ...g_helios_unipc.py => scheduling_helios.py} | 67 ++++++++++--------- src/diffusers/utils/dummy_pt_objects.py | 2 +- 8 files changed, 52 insertions(+), 52 deletions(-) rename docs/source/en/api/schedulers/{helios_unipc.md => helios.md} (71%) rename 
src/diffusers/schedulers/{scheduling_helios_unipc.py => scheduling_helios.py} (94%) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index 7e1d17b25950..4602f4237898 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -14,7 +14,7 @@ from diffusers import HeliosTransformer3DModel from diffusers import HeliosPipeline -from diffusers.schedulers.scheduling_helios_unipc import HeliosUniPCScheduler +from diffusers.schedulers.scheduling_helios import HeliosScheduler from diffusers import ContextParallelConfig from diffusers.models import AutoencoderKLWan @@ -208,7 +208,7 @@ def main(): torch_dtype=torch.float32, ) if args.is_enable_stage2: - scheduler = HeliosUniPCScheduler( + scheduler = HeliosScheduler( shift=args.stage2_timestep_shift, stages=args.stage2_num_stages, stage_range=args.stage2_stage_range, diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 113d95def348..d75218af0da8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -750,8 +750,8 @@ title: FlowMatchEulerDiscreteScheduler - local: api/schedulers/flow_match_heun_discrete title: FlowMatchHeunDiscreteScheduler - - local: api/schedulers/helios_unipc - title: HeliosUniPCScheduler + - local: api/schedulers/helios + title: HeliosScheduler - local: api/schedulers/heun title: HeunDiscreteScheduler - local: api/schedulers/ipndm diff --git a/docs/source/en/api/schedulers/helios_unipc.md b/docs/source/en/api/schedulers/helios.md similarity index 71% rename from docs/source/en/api/schedulers/helios_unipc.md rename to docs/source/en/api/schedulers/helios.md index 97691b42479e..14c2be60bc89 100644 --- a/docs/source/en/api/schedulers/helios_unipc.md +++ b/docs/source/en/api/schedulers/helios.md @@ -10,11 +10,11 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. 
--> -# HeliosUniPCScheduler +# HeliosScheduler -`HeliosUniPCScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers). +`HeliosScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers). -## HeliosUniPCScheduler -[[autodoc]] HeliosUniPCScheduler +## HeliosScheduler +[[autodoc]] HeliosScheduler -scheduling_helios_unipc +scheduling_helios diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a8c5587bc531..012e3d27b97c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -360,7 +360,7 @@ "FlowMatchEulerDiscreteScheduler", "FlowMatchHeunDiscreteScheduler", "FlowMatchLCMScheduler", - "HeliosUniPCScheduler", + "HeliosScheduler", "HeunDiscreteScheduler", "IPNDMScheduler", "KarrasVeScheduler", @@ -1126,7 +1126,7 @@ FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler, FlowMatchLCMScheduler, - HeliosUniPCScheduler, + HeliosScheduler, HeunDiscreteScheduler, IPNDMScheduler, KarrasVeScheduler, diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index c564e0c6b6e4..bfcd3666586c 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -26,7 +26,7 @@ from ...image_processor import PipelineImageInput from ...loaders import HeliosLoraLoaderMixin from ...models import AutoencoderKLWan, HeliosTransformer3DModel -from ...schedulers import HeliosUniPCScheduler, UniPCMultistepScheduler +from ...schedulers import HeliosScheduler, UniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor @@ -210,7 +210,7 @@ def __init__( tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, vae: AutoencoderKLWan, - scheduler: UniPCMultistepScheduler | 
HeliosUniPCScheduler, + scheduler: UniPCMultistepScheduler | HeliosScheduler, transformer: HeliosTransformer3DModel, ): super().__init__() @@ -838,10 +838,7 @@ def stage2_sample( else: latents = pred_image_or_video else: - if scheduler_type == "unipc": - latents = self.scheduler.step_unipc(noise_pred.float(), t, latents, return_dict=False)[0] - else: - latents = self.scheduler.step(noise_pred.float(), t, latents, return_dict=False)[0] + latents = self.scheduler.step(noise_pred.float(), t, latents, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index add9650488d4..56c58dbeb069 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -61,7 +61,7 @@ _import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"] _import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] _import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"] - _import_structure["scheduling_helios_unipc"] = ["HeliosUniPCScheduler"] + _import_structure["scheduling_helios"] = ["HeliosScheduler"] _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"] _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"] _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"] @@ -165,7 +165,7 @@ from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler from .scheduling_flow_match_lcm import FlowMatchLCMScheduler - from .scheduling_helios_unipc import HeliosUniPCScheduler + from .scheduling_helios import HeliosScheduler from .scheduling_heun_discrete import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler diff --git 
a/src/diffusers/schedulers/scheduling_helios_unipc.py b/src/diffusers/schedulers/scheduling_helios.py similarity index 94% rename from src/diffusers/schedulers/scheduling_helios_unipc.py rename to src/diffusers/schedulers/scheduling_helios.py index 28dc5aa76f26..3d4aa5475e6b 100644 --- a/src/diffusers/schedulers/scheduling_helios_unipc.py +++ b/src/diffusers/schedulers/scheduling_helios.py @@ -11,14 +11,14 @@ @dataclass -class HeliosUniPCSchedulerOutput(BaseOutput): +class HeliosSchedulerOutput(BaseOutput): prev_sample: torch.FloatTensor model_outputs: torch.FloatTensor last_sample: torch.FloatTensor this_order: int -class HeliosUniPCScheduler(SchedulerMixin, ConfigMixin): +class HeliosScheduler(SchedulerMixin, ConfigMixin): _compatibles = [] order = 1 @@ -41,6 +41,7 @@ def __init__( solver_p: SchedulerMixin = None, use_flow_sigmas: bool = True, version: str = "v1", + scheduler_type: str = "unipc", ): self.version = version self.timestep_ratios = {} # The timestep ratio for each stage @@ -236,6 +237,7 @@ def set_timesteps( self._step_index = None self.reset_scheduler_history() + # ---------------------------------- Euler ---------------------------------- def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -258,7 +260,7 @@ def _init_step_index(self, timestep): else: self._step_index = self._begin_index - def step( + def step_euler( self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] = None, @@ -267,30 +269,7 @@ def step( sigma: Optional[torch.FloatTensor] = None, sigma_next: Optional[torch.FloatTensor] = None, return_dict: bool = True, - ) -> Union[HeliosUniPCSchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion - process from the learned model outputs (most often the predicted noise). 
- - Args: - model_output (`torch.FloatTensor`): - The direct output from learned diffusion model. - timestep (`float`): - The current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - generator (`torch.Generator`, *optional*): - A random number generator. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or - tuple. - - Returns: - [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is - returned, otherwise a tuple is returned where the first element is the sample tensor. - """ - + ) -> Union[HeliosSchedulerOutput, Tuple]: assert (sigma is None) == (sigma_next is None), "sigma and sigma_next must both be None or both be not None" if sigma is None and sigma_next is None: @@ -328,7 +307,7 @@ def step( if not return_dict: return (prev_sample,) - return HeliosUniPCSchedulerOutput(prev_sample=prev_sample) + return HeliosSchedulerOutput(prev_sample=prev_sample) # ---------------------------------- UniPC ---------------------------------- def _sigma_to_alpha_sigma_t(self, sigma): @@ -712,7 +691,7 @@ def step_unipc( cus_lower_order_num: int = None, cus_this_order: int = None, cus_last_sample: torch.Tensor = None, - ) -> Union[HeliosUniPCSchedulerOutput, Tuple]: + ) -> Union[HeliosSchedulerOutput, Tuple]: if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" @@ -744,8 +723,6 @@ def step_unipc( self.model_outputs = model_outputs[:-1] self.timestep_list = timestep_list[:-1] - # print("1", self.step_index, self.timestep_list) - if use_corrector: sample = self.multistep_uni_c_bh_update( this_model_output=model_output_convert, @@ -794,13 +771,39 @@ def step_unipc( if not return_dict: 
return (prev_sample, model_outputs, self.last_sample, self.this_order) - return HeliosUniPCSchedulerOutput( + return HeliosSchedulerOutput( prev_sample=prev_sample, model_outputs=model_outputs, last_sample=self.last_sample, this_order=self.this_order, ) + # ---------------------------------- Merge ---------------------------------- + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor] = None, + sample: torch.FloatTensor = None, + return_dict: bool = True, + scheduler_type: str = "unipc", + ) -> Union[HeliosSchedulerOutput, Tuple]: + if scheduler_type == "unipc": + return self.step_unipc( + model_output=model_output, + timestep=timestep, + sample=sample, + return_dict=return_dict, + ) + elif scheduler_type == "euler": + return self.step_euler( + model_output=model_output, + timestep=timestep, + sample=sample, + return_dict=return_dict, + ) + else: + raise NotImplementedError + def reset_scheduler_history(self): self.model_outputs = [None] * self.config.solver_order self.timestep_list = [None] * self.config.solver_order diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 55ed0c212d7a..6c12a4f04d2f 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2758,7 +2758,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HeliosUniPCScheduler(metaclass=DummyObject): +class HeliosScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From 9413826ea4c5280960df0588161a3fe4aea6185f Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 08:19:12 +0000 Subject: [PATCH 027/107] add some details to doc --- docs/source/en/api/pipelines/helios.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index a8f2c14104db..418aaced1fb0 100644 ---
a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -28,9 +28,9 @@ The following Helios models are supported in Diffusers: -- [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): Best Quality -- [Helios-Mid](https://huggingface.co/BestWishYsh/Helios-Mid): Intermediate Weight -- [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): Best Efficiency +- [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): Best Quality, with v-prediction, standard CFG and UniPCMultistepScheduler. +- [Helios-Mid](https://huggingface.co/BestWishYsh/Helios-Mid): Intermediate Weight, with v-prediction, CFG-Zero* and custom HeliosScheduler. +- [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): Best Efficiency, with x0-prediction and custom HeliosScheduler. > [!TIP] > Click on the Helios models in the right sidebar for more examples of video generation. From 3bd8b865791e74c7fe04174a1a67d3504098e9d4 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 09:16:27 +0000 Subject: [PATCH 028/107] move dmd step() logics from pipeline to scheduler --- 0_temp_helios_test/infer_helios.py | 5 - .../pipelines/helios/pipeline_helios.py | 104 ++++++--------- src/diffusers/schedulers/scheduling_helios.py | 120 +++++++++++++++--- 3 files changed, 141 insertions(+), 88 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index 4602f4237898..b06f86ac5d1a 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -300,7 +300,6 @@ def parse_list_input(input_string): is_enable_stage2=args.is_enable_stage2, stage2_num_stages=args.stage2_num_stages, stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, - scheduler_type="unipc", # stage 3 use_dmd=args.is_enable_stage3, is_skip_first_section=args.is_skip_first_section, @@ -362,7 +361,6 @@ def parse_list_input(input_string): is_enable_stage2=args.is_enable_stage2, 
stage2_num_stages=args.stage2_num_stages, stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, - scheduler_type="unipc", # stage 3 use_dmd=args.is_enable_stage3, is_skip_first_section=args.is_skip_first_section, @@ -432,7 +430,6 @@ def parse_list_input(input_string): is_enable_stage2=args.is_enable_stage2, stage2_num_stages=args.stage2_num_stages, stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, - scheduler_type="unipc", # stage 3 use_dmd=args.is_enable_stage3, is_skip_first_section=args.is_skip_first_section, @@ -491,7 +488,6 @@ def parse_list_input(input_string): is_enable_stage2=args.is_enable_stage2, stage2_num_stages=args.stage2_num_stages, stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, - scheduler_type="unipc", # stage 3 use_dmd=args.is_enable_stage3, is_skip_first_section=args.is_skip_first_section, @@ -535,7 +531,6 @@ def parse_list_input(input_string): is_enable_stage2=args.is_enable_stage2, stage2_num_stages=args.stage2_num_stages, stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, - scheduler_type="unipc", # stage 3 use_dmd=args.is_enable_stage3, is_skip_first_section=args.is_skip_first_section, diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index bfcd3666586c..06242016aeeb 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -158,27 +158,6 @@ def apply_schedule_shift( return sigmas -def add_noise(original_samples, noise, timestep, sigmas, timesteps): - sigmas = sigmas.to(noise.device) - timesteps = timesteps.to(noise.device) - timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1) - sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1) - sample = (1 - sigma) * original_samples + sigma * noise - return sample.type_as(noise) - - -def convert_flow_pred_to_x0(flow_pred, xt, timestep, sigmas, timesteps): - # use higher precision for calculations - original_dtype = flow_pred.dtype - device = flow_pred.device - flow_pred, xt, sigmas, timesteps = (x.double().to(device) for x in (flow_pred, xt, sigmas, timesteps)) - - timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1) - sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1) - x0_pred = xt - sigma_t * flow_pred - return x0_pred.to(original_dtype) - - class HeliosPipeline(DiffusionPipeline, HeliosLoraLoaderMixin): r""" Pipeline for text-to-video / image-to-video / video-to-video generation using Helios. 
@@ -542,7 +521,6 @@ def stage1_sample( attention_kwargs: Optional[dict] = None, device: Optional[torch.device] = None, transformer_dtype: torch.dtype = None, - scheduler_type: str = "unipc", use_dynamic_shifting: bool = False, generator: Optional[torch.Generator] = None, # ------------ CFG Zero ------------ @@ -618,27 +596,25 @@ def stage1_sample( else: noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - if use_dmd: - pred_image_or_video = convert_flow_pred_to_x0( - flow_pred=noise_pred, - xt=latent_model_input, - timestep=t * torch.ones(batch_size, dtype=torch.long, device=noise_pred.device), - sigmas=dmd_sigmas, - timesteps=dmd_timesteps, - ) - if i < len(timesteps) - 1: - latents = add_noise( - pred_image_or_video, - randn_tensor(pred_image_or_video.shape, generator=generator, device=device), - timesteps[i + 1] * torch.ones(batch_size, dtype=torch.long, device=noise_pred.device), - sigmas=dmd_sigmas, - timesteps=dmd_timesteps, - ) - else: - latents = pred_image_or_video + if isinstance(self.scheduler, HeliosScheduler): + latents = self.scheduler.step( + noise_pred, + t, + latents, + return_dict=False, + generator=generator, + cur_sampling_step=i, + dmd_sigmas=dmd_sigmas, + dmd_timesteps=dmd_timesteps, + all_timesteps=timesteps, + )[0] else: - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, + t, + latents, + return_dict=False, + )[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -675,8 +651,8 @@ def stage2_sample( attention_kwargs: Optional[dict] = None, device: Optional[torch.device] = None, transformer_dtype: torch.dtype = None, - scheduler_type: str = "unipc", # unipc, euler use_dynamic_shifting: bool = False, + generator: Optional[torch.Generator] = None, # ------------ CFG Zero ------------ use_cfg_zero_star: Optional[bool] = False, use_zero_init: Optional[bool] = True, @@ -819,26 +795,25 @@ def 
stage2_sample( else: noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - if use_dmd: - pred_image_or_video = convert_flow_pred_to_x0( - flow_pred=noise_pred, - xt=latents, - timestep=timestep, - sigmas=self.scheduler.sigmas, - timesteps=self.scheduler.timesteps, - ) - if idx < len(timesteps) - 1: - latents = add_noise( - pred_image_or_video, - start_point_list[i_s], - timesteps[idx + 1] * torch.ones(batch_size, dtype=torch.long, device=noise_pred.device), - sigmas=self.scheduler.sigmas, - timesteps=self.scheduler.timesteps, - ) - else: - latents = pred_image_or_video + if isinstance(self.scheduler, HeliosScheduler): + latents = self.scheduler.step( + noise_pred, + t, + latents, + return_dict=False, + generator=generator, + cur_sampling_step=i, + dmd_sigmas=self.scheduler.sigmas, + dmd_timesteps=self.scheduler.timesteps, + all_timesteps=timesteps, + )[0] else: - latents = self.scheduler.step(noise_pred.float(), t, latents, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, + t, + latents, + return_dict=False, + )[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -933,7 +908,6 @@ def __call__( is_enable_stage2: bool = False, stage2_num_stages: int = 3, stage2_num_inference_steps_list: list = [10, 10, 10], - scheduler_type: str = "unipc", # unipc, euler # ------------ CFG Zero ------------ use_cfg_zero_star: Optional[bool] = False, use_zero_init: Optional[bool] = True, @@ -1377,7 +1351,6 @@ def __call__( attention_kwargs=attention_kwargs, device=device, transformer_dtype=transformer_dtype, - scheduler_type=scheduler_type, use_dynamic_shifting=use_dynamic_shifting, # ------------ CFG Zero ------------ use_cfg_zero_star=use_cfg_zero_star, @@ -1408,7 +1381,6 @@ def __call__( attention_kwargs=attention_kwargs, device=device, transformer_dtype=transformer_dtype, - scheduler_type=scheduler_type, use_dynamic_shifting=use_dynamic_shifting, generator=generator, # ------------ CFG Zero ------------ diff --git 
a/src/diffusers/schedulers/scheduling_helios.py b/src/diffusers/schedulers/scheduling_helios.py index 3d4aa5475e6b..12788da1bc7e 100644 --- a/src/diffusers/schedulers/scheduling_helios.py +++ b/src/diffusers/schedulers/scheduling_helios.py @@ -1,13 +1,28 @@ +# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Optional import numpy as np import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import SchedulerMixin -from diffusers.utils import BaseOutput, deprecate +from ..configuration_utils import ConfigMixin, register_to_config +from ..schedulers.scheduling_utils import SchedulerMixin +from ..utils import BaseOutput, deprecate +from ..utils.torch_utils import randn_tensor @dataclass @@ -41,7 +56,7 @@ def __init__( solver_p: SchedulerMixin = None, use_flow_sigmas: bool = True, version: str = "v1", - scheduler_type: str = "unipc", + scheduler_type: str = "unipc", # ["euler", "unipc", "dmd"] ): self.version = version self.timestep_ratios = {} # The timestep ratio for each stage @@ -192,7 +207,7 @@ def set_timesteps( self, num_inference_steps: int, stage_index: int, - device: Union[str, torch.device] = None, + device: str | torch.device = None, ): """ Setting the timesteps and sigmas for each stage @@ 
-263,13 +278,13 @@ def _init_step_index(self, timestep): def step_euler( self, model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor] = None, + timestep: float | torch.FloatTensor = None, sample: torch.FloatTensor = None, generator: Optional[torch.Generator] = None, sigma: Optional[torch.FloatTensor] = None, sigma_next: Optional[torch.FloatTensor] = None, return_dict: bool = True, - ) -> Union[HeliosSchedulerOutput, Tuple]: + ) -> HeliosSchedulerOutput | tuple: assert (sigma is None) == (sigma_next is None), "sigma and sigma_next must both be None or both be not None" if sigma is None and sigma_next is None: @@ -679,7 +694,7 @@ def multistep_uni_c_bh_update( def step_unipc( self, model_output: torch.Tensor, - timestep: Union[int, torch.Tensor] = None, + timestep: int | torch.Tensor = None, sample: torch.Tensor = None, return_dict: bool = True, model_outputs: list = None, @@ -691,7 +706,7 @@ def step_unipc( cus_lower_order_num: int = None, cus_this_order: int = None, cus_last_sample: torch.Tensor = None, - ) -> Union[HeliosSchedulerOutput, Tuple]: + ) -> HeliosSchedulerOutput | tuple: if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" @@ -778,29 +793,100 @@ def step_unipc( this_order=self.this_order, ) + # ---------------------------------- For DMD ---------------------------------- + def add_noise(self, original_samples, noise, timestep, sigmas, timesteps): + sigmas = sigmas.to(noise.device) + timesteps = timesteps.to(noise.device) + timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1) + sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1) + sample = (1 - sigma) * original_samples + sigma * noise + return sample.type_as(noise) + + def convert_flow_pred_to_x0(self, flow_pred, xt, timestep, sigmas, timesteps): + # use higher precision for calculations + original_dtype = flow_pred.dtype + device = 
flow_pred.device + flow_pred, xt, sigmas, timesteps = (x.double().to(device) for x in (flow_pred, xt, sigmas, timesteps)) + + timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1) + sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1) + x0_pred = xt - sigma_t * flow_pred + return x0_pred.to(original_dtype) + + def step_dmd( + self, + model_output: torch.FloatTensor, + timestep: float | torch.FloatTensor = None, + sample: torch.FloatTensor = None, + generator: torch.Generator | None = None, + cur_sampling_step: int = 0, + dmd_sigmas: torch.FloatTensor | None = None, + dmd_timesteps: torch.FloatTensor | None = None, + all_timesteps: torch.FloatTensor | None = None, + ): + prev_sample = self.convert_flow_pred_to_x0( + flow_pred=model_output, + xt=sample, + timestep=torch.full((model_output.shape[0],), timestep, dtype=torch.long, device=model_output.device), + sigmas=dmd_sigmas, + timesteps=dmd_timesteps, + ) + if cur_sampling_step < len(all_timesteps) - 1: + prev_sample = self.add_noise( + prev_sample, + randn_tensor(prev_sample.shape, generator=generator, device=model_output.device), + torch.full( + (model_output.shape[0],), + all_timesteps[cur_sampling_step + 1], + dtype=torch.long, + device=model_output.device, + ), + sigmas=dmd_sigmas, + timesteps=dmd_timesteps, + ) + + return HeliosSchedulerOutput(prev_sample=prev_sample) + # ---------------------------------- Merge ---------------------------------- def step( self, model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor] = None, + timestep: float | torch.FloatTensor = None, sample: torch.FloatTensor = None, + generator: Optional[torch.Generator] = None, return_dict: bool = True, - scheduler_type: str = "unipc", - ) -> Union[HeliosSchedulerOutput, Tuple]: - if scheduler_type == "unipc": - self.step_euler( + # For DMD + cur_sampling_step: int = 0, + dmd_sigmas: torch.FloatTensor | None = None, + dmd_timesteps: torch.FloatTensor | None = None, + 
all_timesteps: torch.FloatTensor | None = None, + ) -> HeliosSchedulerOutput | tuple: + if self.config.scheduler_type == "euler": + return self.step_euler( model_output=model_output, timestep=timestep, sample=sample, + generator=generator, return_dict=return_dict, ) - elif scheduler_type == "euler": - self.step_unipc( + elif self.config.scheduler_type == "unipc": + return self.step_unipc( model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict, ) + elif self.config.scheduler_type == "dmd": + return self.step_dmd( + model_output=model_output, + timestep=timestep, + sample=sample, + generator=generator, + cur_sampling_step=cur_sampling_step, + dmd_sigmas=dmd_sigmas, + dmd_timesteps=dmd_timesteps, + all_timesteps=all_timesteps, + ) else: raise NotImplementedError From f5fe040d35e496d02acc728cdc8c569d6d7eb5b5 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 09:38:38 +0000 Subject: [PATCH 029/107] change to Python 3.9+ style type --- .../models/transformers/transformer_helios.py | 30 ++-- .../pipelines/helios/pipeline_helios.py | 128 +++++++++--------- src/diffusers/schedulers/scheduling_helios.py | 13 +- 3 files changed, 84 insertions(+), 87 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index e7ebe7c1fff3..1b506231c356 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -14,7 +14,7 @@ import math from functools import lru_cache -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any import torch import torch.nn as nn @@ -96,9 +96,9 @@ def __call__( self, attn: "HeliosAttention", hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, +
attention_mask: torch.Tensor | None = None, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, original_context_length: int = None, ) -> torch.Tensor: query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) @@ -153,8 +153,8 @@ def __init__( dim_head: int = 64, eps: float = 1e-5, dropout: float = 0.0, - added_kv_proj_dim: Optional[int] = None, - cross_attention_dim_head: Optional[int] = None, + added_kv_proj_dim: int | None = None, + cross_attention_dim_head: int | None = None, processor=None, is_cross_attention=None, is_amplify_history=False, @@ -256,9 +256,9 @@ def unfuse_projections(self): def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, original_context_length: int = None, **kwargs, ) -> torch.Tensor: @@ -292,7 +292,7 @@ def __init__( def forward( self, timestep: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_hidden_states: torch.Tensor | None = None, is_return_encoder_hidden_states: bool = True, ): timestep = self.timesteps_proj(timestep) @@ -367,10 +367,10 @@ def __init__( qk_norm: str = "rms_norm_across_heads", cross_attn_norm: bool = False, eps: float = 1e-6, - added_kv_proj_dim: Optional[int] = None, + added_kv_proj_dim: int | None = None, guidance_cross_attn: bool = False, is_amplify_history: bool = False, - history_scale_mode: str = "per_head", # [scalar, per_head], + history_scale_mode: str = "per_head", # [scalar, per_head] ): super().__init__() @@ -489,7 +489,7 @@ class HeliosTransformer3DModel( A Transformer model for video-like data used in the Helios model. 
Args: - patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + patch_size (`tuple[int]`, defaults to `(1, 2, 2)`): 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). num_attention_heads (`int`, defaults to `40`): Fixed length for text embeddings. @@ -507,7 +507,7 @@ class HeliosTransformer3DModel( Intermediate dimension in feed-forward network. num_layers (`int`, defaults to `40`): The number of layers of transformer blocks to use. - window_size (`Tuple[int]`, defaults to `(-1, -1)`): + window_size (`tuple[int]`, defaults to `(-1, -1)`): Window size for local attention (-1 indicates global attention). cross_attn_norm (`bool`, defaults to `True`): Enable cross-attention normalization. @@ -636,7 +636,7 @@ def forward( latents_history_long=None, return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, - ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + ) -> torch.Tensor | dict[str, torch.Tensor]: assert ( len( { diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 06242016aeeb..858e0a57f1a6 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -15,7 +15,7 @@ import html import math from itertools import accumulate -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable import regex as re import torch @@ -207,11 +207,11 @@ def __init__( def _get_t5_prompt_embeds( self, - prompt: Union[str, List[str]] = None, + prompt: str | list[str] = None, num_videos_per_prompt: int = 1, max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -252,19 +252,19 @@ def encode_prompt( negative_prompt: str | list[str] | None = None, do_classifier_free_guidance: 
bool = True, num_videos_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `list[str]`, *optional*): prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): + negative_prompt (`str` or `list[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). @@ -390,10 +390,10 @@ def prepare_latents( height: int = 384, width: int = 640, num_frames: int = 33, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, ) -> torch.Tensor: if latents is not None: return latents.to(device=device, dtype=dtype) @@ -420,11 +420,11 @@ def prepare_image_latents( image: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - fake_latents: Optional[torch.Tensor] = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | 
None = None, + latents: torch.Tensor | None = None, + fake_latents: torch.Tensor | None = None, ) -> torch.Tensor: device = device or self._execution_device if latents is None: @@ -444,10 +444,10 @@ def prepare_video_latents( latents_mean: torch.Tensor, latents_std: torch.Tensor, latent_window_size: int, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, ) -> torch.Tensor: device = device or self._execution_device video = video.to(device=device, dtype=self.vae.dtype) @@ -510,7 +510,7 @@ def stage1_sample( prompt_embeds: torch.Tensor = None, negative_prompt_embeds: torch.Tensor = None, timesteps: torch.Tensor = None, - guidance_scale: Optional[float] = 5.0, + guidance_scale: float | None = 5.0, indices_hidden_states: torch.Tensor = None, indices_latents_history_short: torch.Tensor = None, indices_latents_history_mid: torch.Tensor = None, @@ -518,22 +518,22 @@ def stage1_sample( latents_history_short: torch.Tensor = None, latents_history_mid: torch.Tensor = None, latents_history_long: torch.Tensor = None, - attention_kwargs: Optional[dict] = None, - device: Optional[torch.device] = None, + attention_kwargs: dict | None = None, + device: torch.device | None = None, transformer_dtype: torch.dtype = None, use_dynamic_shifting: bool = False, - generator: Optional[torch.Generator] = None, + generator: torch.Generator | None = None, # ------------ CFG Zero ------------ - use_cfg_zero_star: Optional[bool] = False, - use_zero_init: Optional[bool] = True, - zero_steps: Optional[int] = 1, + use_cfg_zero_star: bool | None = False, + use_zero_init: bool | None = True, + zero_steps: int | None = 1, # -------------- DMD -------------- use_dmd: bool = False, 
dmd_sigmas: torch.Tensor = None, dmd_timesteps: torch.Tensor = None, is_amplify_first_chunk: bool = False, # ------------ Callback ------------ - callback_on_step_end: Optional[callable] = None, + callback_on_step_end: callable | None = None, callback_on_step_end_tensor_inputs: list = None, progress_bar=None, ): @@ -637,10 +637,10 @@ def stage2_sample( self, latents: torch.Tensor = None, stage2_num_stages: int = None, - stage2_num_inference_steps_list: List[int] = None, + stage2_num_inference_steps_list: list[int] = None, prompt_embeds: torch.Tensor = None, negative_prompt_embeds: torch.Tensor = None, - guidance_scale: Optional[float] = 5.0, + guidance_scale: float | None = 5.0, indices_hidden_states: torch.Tensor = None, indices_latents_history_short: torch.Tensor = None, indices_latents_history_mid: torch.Tensor = None, @@ -648,20 +648,20 @@ def stage2_sample( latents_history_short: torch.Tensor = None, latents_history_mid: torch.Tensor = None, latents_history_long: torch.Tensor = None, - attention_kwargs: Optional[dict] = None, - device: Optional[torch.device] = None, + attention_kwargs: dict | None = None, + device: torch.device | None = None, transformer_dtype: torch.dtype = None, use_dynamic_shifting: bool = False, - generator: Optional[torch.Generator] = None, + generator: torch.Generator | None = None, # ------------ CFG Zero ------------ - use_cfg_zero_star: Optional[bool] = False, - use_zero_init: Optional[bool] = True, - zero_steps: Optional[int] = 1, + use_cfg_zero_star: bool | None = False, + use_zero_init: bool | None = True, + zero_steps: int | None = 1, # -------------- DMD -------------- use_dmd: bool = False, is_amplify_first_chunk: bool = False, # ------------ Callback ------------ - callback_on_step_end: Optional[callable] = None, + callback_on_step_end: callable | None = None, callback_on_step_end_tensor_inputs: list = None, progress_bar=None, ): @@ -862,36 +862,34 @@ def attention_kwargs(self): @replace_example_docstring(EXAMPLE_DOC_STRING) 
def __call__( self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, height: int = 384, width: int = 640, num_frames: int = 132, num_inference_steps: int = 50, guidance_scale: float = 5.0, - num_videos_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - output_type: Optional[str] = "np", + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], max_sequence_length: int = 512, # ------------ I2V ------------ - image: Optional[PipelineImageInput] = None, - image_latents: Optional[torch.Tensor] = None, - fake_image_latents: Optional[torch.Tensor] = None, + image: PipelineImageInput | None = None, + image_latents: torch.Tensor | None = None, + fake_image_latents: torch.Tensor | None = None, add_noise_to_image_latents: bool = True, image_noise_sigma_min: float = 0.111, image_noise_sigma_max: float = 0.135, # ------------ V2V ------------ - video: Optional[PipelineImageInput] = None, - video_latents: Optional[torch.Tensor] = None, + video: 
PipelineImageInput | None = None, + video_latents: torch.Tensor | None = None, add_noise_to_video_latents: bool = True, video_noise_sigma_min: float = 0.111, video_noise_sigma_max: float = 0.135, @@ -909,9 +907,9 @@ def __call__( stage2_num_stages: int = 3, stage2_num_inference_steps_list: list = [10, 10, 10], # ------------ CFG Zero ------------ - use_cfg_zero_star: Optional[bool] = False, - use_zero_init: Optional[bool] = True, - zero_steps: Optional[int] = 1, + use_cfg_zero_star: bool | None = False, + use_zero_init: bool | None = True, + zero_steps: int | None = 1, # ------------ DMD ------------ use_dmd: bool = False, is_skip_first_section: bool = False, @@ -921,9 +919,9 @@ def __call__( The call function to the pipeline for generation. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `list[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. - negative_prompt (`str` or `List[str]`, *optional*): + negative_prompt (`str` or `list[str]`, *optional*): The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale` < `1`). height (`int`, defaults to `384`): @@ -943,7 +941,7 @@ def __call__( the text `prompt`, usually at the expense of lower image quality. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): @@ -966,7 +964,7 @@ def __call__( each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. 
`callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): + callback_on_step_end_tensor_inputs (`list`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. diff --git a/src/diffusers/schedulers/scheduling_helios.py b/src/diffusers/schedulers/scheduling_helios.py index 12788da1bc7e..8f9c078495e9 100644 --- a/src/diffusers/schedulers/scheduling_helios.py +++ b/src/diffusers/schedulers/scheduling_helios.py @@ -14,7 +14,6 @@ import math from dataclasses import dataclass -from typing import List, Optional import numpy as np import torch @@ -43,7 +42,7 @@ def __init__( num_train_timesteps: int = 1000, shift: float = 1.0, # Following Stable diffusion 3, stages: int = 3, - stage_range: List = [0, 1 / 3, 2 / 3, 1], + stage_range: list = [0, 1 / 3, 2 / 3, 1], gamma: float = 1 / 3, # For UniPC thresholding: bool = False, @@ -52,7 +51,7 @@ def __init__( predict_x0: bool = True, solver_type: str = "bh2", lower_order_final: bool = True, - disable_corrector: List[int] = [], + disable_corrector: list[int] = [], solver_p: SchedulerMixin = None, use_flow_sigmas: bool = True, version: str = "v1", @@ -280,9 +279,9 @@ def step_euler( model_output: torch.FloatTensor, timestep: float | torch.FloatTensor = None, sample: torch.FloatTensor = None, - generator: Optional[torch.Generator] = None, - sigma: Optional[torch.FloatTensor] = None, - sigma_next: Optional[torch.FloatTensor] = None, + generator: torch.Generator | None = None, + sigma: torch.FloatTensor | None = None, + sigma_next: torch.FloatTensor | None = None, return_dict: bool = True, ) -> HeliosSchedulerOutput | tuple: assert (sigma is None) == (sigma_next is None), "sigma and sigma_next must both be 
None or both be not None" @@ -853,7 +852,7 @@ def step( model_output: torch.FloatTensor, timestep: float | torch.FloatTensor = None, sample: torch.FloatTensor = None, - generator: Optional[torch.Generator] = None, + generator: torch.Generator | None = None, return_dict: bool = True, # For DMD cur_sampling_step: int = 0, From 1572cf06ba87b0371cc7cf03bd75f1a1e8e56fe7 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 09:43:31 +0000 Subject: [PATCH 030/107] fix NoneType error --- src/diffusers/pipelines/helios/pipeline_helios.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 858e0a57f1a6..992af0d927e1 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -533,8 +533,8 @@ def stage1_sample( dmd_timesteps: torch.Tensor = None, is_amplify_first_chunk: bool = False, # ------------ Callback ------------ - callback_on_step_end: callable | None = None, - callback_on_step_end_tensor_inputs: list = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], progress_bar=None, ): batch_size = latents.shape[0] @@ -661,8 +661,8 @@ def stage2_sample( use_dmd: bool = False, is_amplify_first_chunk: bool = False, # ------------ Callback ------------ - callback_on_step_end: callable | None = None, - callback_on_step_end_tensor_inputs: list = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], progress_bar=None, ): batch_size, num_channel, num_frmaes, height, width = latents.shape From 3b78e9b489eec6954822a34957a5ce17a5d614b8 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 13:54:07 +0000 Subject: [PATCH 031/107] 
refactor DMD scheduler's set_timestep --- 0_temp_helios_test/infer_helios.py | 27 ++++------------- .../pipelines/helios/pipeline_helios.py | 18 +++++------ src/diffusers/schedulers/scheduling_helios.py | 30 +++++++++++++++---- 3 files changed, 37 insertions(+), 38 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index b06f86ac5d1a..a4a730a2dd53 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -207,27 +207,12 @@ def main(): subfolder="vae", torch_dtype=torch.float32, ) - if args.is_enable_stage2: - scheduler = HeliosScheduler( - shift=args.stage2_timestep_shift, - stages=args.stage2_num_stages, - stage_range=args.stage2_stage_range, - gamma=args.stage2_scheduler_gamma, - ) - pipe = HeliosPipeline.from_pretrained( - args.base_model_path, - transformer=transformer, - vae=vae, - scheduler=scheduler, - torch_dtype=args.weight_dtype, - ) - else: - pipe = HeliosPipeline.from_pretrained( - args.base_model_path, - transformer=transformer, - vae=vae, - torch_dtype=args.weight_dtype, - ) + pipe = HeliosPipeline.from_pretrained( + args.base_model_path, + transformer=transformer, + vae=vae, + torch_dtype=args.weight_dtype, + ) if args.lora_path is not None: pipe.load_lora_weights(args.lora_path, adapter_name="default") diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 992af0d927e1..1df8772481ab 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -601,9 +601,10 @@ def stage1_sample( noise_pred, t, latents, - return_dict=False, generator=generator, + return_dict=False, cur_sampling_step=i, + dmd_noisy_tensor=randn_tensor(noise_pred.shape, generator=generator, device=device), dmd_sigmas=dmd_sigmas, dmd_timesteps=dmd_timesteps, all_timesteps=timesteps, @@ -686,15 +687,9 @@ def stage2_sample( i = 0 for i_s in range(stage2_num_stages): - if use_dmd: - if 
is_amplify_first_chunk: - self.scheduler.set_timesteps(stage2_num_inference_steps_list[i_s] * 2 + 1, i_s, device=device) - else: - self.scheduler.set_timesteps(stage2_num_inference_steps_list[i_s] + 1, i_s, device=device) - self.scheduler.timesteps = self.scheduler.timesteps[:-1] - self.scheduler.sigmas = torch.cat([self.scheduler.sigmas[:-2], self.scheduler.sigmas[-1:]]) - else: - self.scheduler.set_timesteps(stage2_num_inference_steps_list[i_s], i_s, device=device) + self.scheduler.set_timesteps( + stage2_num_inference_steps_list[i_s], i_s, device=device, is_amplify_first_chunk=is_amplify_first_chunk + ) if i_s > 0: height *= 2 @@ -800,9 +795,10 @@ def stage2_sample( noise_pred, t, latents, - return_dict=False, generator=generator, + return_dict=False, cur_sampling_step=i, + dmd_noisy_tensor=start_point_list[i_s], dmd_sigmas=self.scheduler.sigmas, dmd_timesteps=self.scheduler.timesteps, all_timesteps=timesteps, diff --git a/src/diffusers/schedulers/scheduling_helios.py b/src/diffusers/schedulers/scheduling_helios.py index 8f9c078495e9..28cf2e2b19a1 100644 --- a/src/diffusers/schedulers/scheduling_helios.py +++ b/src/diffusers/schedulers/scheduling_helios.py @@ -21,15 +21,14 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..schedulers.scheduling_utils import SchedulerMixin from ..utils import BaseOutput, deprecate -from ..utils.torch_utils import randn_tensor @dataclass class HeliosSchedulerOutput(BaseOutput): prev_sample: torch.FloatTensor - model_outputs: torch.FloatTensor - last_sample: torch.FloatTensor - this_order: int + model_outputs: torch.FloatTensor | None = None + last_sample: torch.FloatTensor | None = None + this_order: int | None = None class HeliosScheduler(SchedulerMixin, ConfigMixin): @@ -207,10 +206,17 @@ def set_timesteps( num_inference_steps: int, stage_index: int, device: str | torch.device = None, + is_amplify_first_chunk: bool = False, ): """ Setting the timesteps and sigmas for each stage """ + if 
self.config.scheduler_type == "dmd": + if is_amplify_first_chunk: + num_inference_steps = num_inference_steps * 2 + 1 + else: + num_inference_steps = num_inference_steps + 1 + self.num_inference_steps = num_inference_steps self.init_sigmas() @@ -251,6 +257,10 @@ def set_timesteps( self._step_index = None self.reset_scheduler_history() + if self.config.scheduler_type == "dmd": + self.timesteps = self.timesteps[:-1] + self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]]) + # ---------------------------------- Euler ---------------------------------- def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: @@ -818,7 +828,9 @@ def step_dmd( timestep: float | torch.FloatTensor = None, sample: torch.FloatTensor = None, generator: torch.Generator | None = None, + return_dict: bool = True, cur_sampling_step: int = 0, + dmd_noisy_tensor: torch.FloatTensor | None = None, dmd_sigmas: torch.FloatTensor | None = None, dmd_timesteps: torch.FloatTensor | None = None, all_timesteps: torch.FloatTensor | None = None, @@ -833,7 +845,7 @@ def step_dmd( if cur_sampling_step < len(all_timesteps) - 1: prev_sample = self.add_noise( prev_sample, - randn_tensor(prev_sample.shape, generator=generator, device=model_output.device), + dmd_noisy_tensor, torch.full( (model_output.shape[0],), all_timesteps[cur_sampling_step + 1], @@ -844,6 +856,9 @@ def step_dmd( timesteps=dmd_timesteps, ) + if not return_dict: + return (prev_sample,) + return HeliosSchedulerOutput(prev_sample=prev_sample) # ---------------------------------- Merge ---------------------------------- @@ -856,6 +871,7 @@ def step( return_dict: bool = True, # For DMD cur_sampling_step: int = 0, + dmd_noisy_tensor: torch.FloatTensor | None = None, dmd_sigmas: torch.FloatTensor | None = None, dmd_timesteps: torch.FloatTensor | None = None, all_timesteps: torch.FloatTensor | None = None, @@ -876,12 +892,14 @@ def step( return_dict=return_dict, ) elif self.config.scheduler_type == "dmd": - 
self.step_dmd( + return self.step_dmd( model_output=model_output, timestep=timestep, sample=sample, generator=generator, + return_dict=return_dict, cur_sampling_step=cur_sampling_step, + dmd_noisy_tensor=dmd_noisy_tensor, dmd_sigmas=dmd_sigmas, dmd_timesteps=dmd_timesteps, all_timesteps=all_timesteps, From ba4d2033b7b16f8fb273612269a65b223fbf71ad Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 14:03:08 +0000 Subject: [PATCH 032/107] change rope related vars name --- .../models/transformers/transformer_helios.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 1b506231c356..9bf6301f9991 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -331,28 +331,28 @@ def get_frequency_batched(self, freqs_base, pos): @lru_cache(maxsize=32) def _get_spatial_meshgrid(self, height, width, device_str): device = torch.device(device_str) - gy = torch.arange(height, device=device, dtype=torch.float32) - gx = torch.arange(width, device=device, dtype=torch.float32) - GY, GX = torch.meshgrid(gy, gx, indexing="ij") - return GY, GX + grid_y_coords = torch.arange(height, device=device, dtype=torch.float32) + grid_x_coords = torch.arange(width, device=device, dtype=torch.float32) + grid_y, grid_x = torch.meshgrid(grid_y_coords, grid_x_coords, indexing="ij") + return grid_y, grid_x @torch.no_grad() def forward(self, frame_indices, height, width, device): - B = frame_indices.shape[0] - T = frame_indices.shape[1] + batch_size = frame_indices.shape[0] + num_frames = frame_indices.shape[1] frame_indices = frame_indices.to(device=device, dtype=torch.float32) - GY, GX = self._get_spatial_meshgrid(height, width, str(device)) + grid_y, grid_x = self._get_spatial_meshgrid(height, width, str(device)) - GT = frame_indices[:, :, None, None].expand(B, 
T, height, width) - GY_batch = GY[None, None, :, :].expand(B, T, -1, -1) - GX_batch = GX[None, None, :, :].expand(B, T, -1, -1) + grid_t = frame_indices[:, :, None, None].expand(batch_size, num_frames, height, width) + grid_y_batch = grid_y[None, None, :, :].expand(batch_size, num_frames, -1, -1) + grid_x_batch = grid_x[None, None, :, :].expand(batch_size, num_frames, -1, -1) - FCT, FST = self.get_frequency_batched(self.freqs_base_t, GT) - FCY, FSY = self.get_frequency_batched(self.freqs_base_y, GY_batch) - FCX, FSX = self.get_frequency_batched(self.freqs_base_x, GX_batch) + freqs_cos_t, freqs_sin_t = self.get_frequency_batched(self.freqs_base_t, grid_t) + freqs_cos_y, freqs_sin_y = self.get_frequency_batched(self.freqs_base_y, grid_y_batch) + freqs_cos_x, freqs_sin_x = self.get_frequency_batched(self.freqs_base_x, grid_x_batch) - result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0) + result = torch.cat([freqs_cos_t, freqs_cos_y, freqs_cos_x, freqs_sin_t, freqs_sin_y, freqs_sin_x], dim=0) return result.permute(1, 0, 2, 3, 4) From 4bdcc0a3a970fff7f8e990d06a2458fd598f9142 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 14:32:18 +0000 Subject: [PATCH 033/107] fix stage2 sample --- src/diffusers/pipelines/helios/pipeline_helios.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 1df8772481ab..f21aa9603d30 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -682,6 +682,7 @@ def stage2_sample( latents = latents.reshape(batch_size, num_frmaes, num_channel, height, width).permute(0, 2, 1, 3, 4) batch_size = latents.shape[0] + start_point_list = None if use_dmd: start_point_list = [latents] @@ -798,7 +799,7 @@ def stage2_sample( generator=generator, return_dict=False, cur_sampling_step=i, - dmd_noisy_tensor=start_point_list[i_s], + 
dmd_noisy_tensor=start_point_list[i_s] if start_point_list is not None else None, dmd_sigmas=self.scheduler.sigmas, dmd_timesteps=self.scheduler.timesteps, all_timesteps=timesteps, From 2f7bfc566fd82cec236f7306f723a4055c706f46 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Feb 2026 15:08:47 +0000 Subject: [PATCH 034/107] fix dmd sample --- src/diffusers/pipelines/helios/pipeline_helios.py | 2 +- src/diffusers/schedulers/scheduling_helios.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index f21aa9603d30..5c2b1ffb09d7 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -798,7 +798,7 @@ def stage2_sample( latents, generator=generator, return_dict=False, - cur_sampling_step=i, + cur_sampling_step=idx, dmd_noisy_tensor=start_point_list[i_s] if start_point_list is not None else None, dmd_sigmas=self.scheduler.sigmas, dmd_timesteps=self.scheduler.timesteps, diff --git a/src/diffusers/schedulers/scheduling_helios.py b/src/diffusers/schedulers/scheduling_helios.py index 28cf2e2b19a1..e2bc80d00f99 100644 --- a/src/diffusers/schedulers/scheduling_helios.py +++ b/src/diffusers/schedulers/scheduling_helios.py @@ -835,7 +835,7 @@ def step_dmd( dmd_timesteps: torch.FloatTensor | None = None, all_timesteps: torch.FloatTensor | None = None, ): - prev_sample = self.convert_flow_pred_to_x0( + pred_image_or_video = self.convert_flow_pred_to_x0( flow_pred=model_output, xt=sample, timestep=torch.full((model_output.shape[0],), timestep, dtype=torch.long, device=model_output.device), @@ -844,7 +844,7 @@ def step_dmd( ) if cur_sampling_step < len(all_timesteps) - 1: prev_sample = self.add_noise( - prev_sample, + pred_image_or_video, dmd_noisy_tensor, torch.full( (model_output.shape[0],), @@ -855,6 +855,8 @@ def step_dmd( sigmas=dmd_sigmas, timesteps=dmd_timesteps, ) + else: + 
prev_sample = pred_image_or_video if not return_dict: return (prev_sample,) From 90c355d0b1c3b5b438ceead00d1d46fad4d3d59a Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Fri, 27 Feb 2026 10:37:29 +0800 Subject: [PATCH 035/107] Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu --- .../models/transformers/transformer_helios.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 9bf6301f9991..f6352e2f62ca 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -637,23 +637,6 @@ def forward( return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, ) -> torch.Tensor | dict[str, torch.Tensor]: - assert ( - len( - { - x is None - for x in [ - indices_hidden_states, - indices_latents_history_short, - indices_latents_history_mid, - indices_latents_history_long, - latents_history_short, - latents_history_mid, - latents_history_long, - ] - } - ) - == 1 - ), "All history latents and indices must either all exist or all be None" if indices_hidden_states is not None and indices_hidden_states.ndim == 1: indices_hidden_states = indices_hidden_states.unsqueeze(0) From 8256941e23f9b751d9174472fe4cbaa0f90c3b7b Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Fri, 27 Feb 2026 10:38:06 +0800 Subject: [PATCH 036/107] Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_helios.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index f6352e2f62ca..6022f2272138 100644 --- 
a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -638,14 +638,6 @@ def forward( attention_kwargs: dict[str, Any] | None = None, ) -> torch.Tensor | dict[str, torch.Tensor]: - if indices_hidden_states is not None and indices_hidden_states.ndim == 1: - indices_hidden_states = indices_hidden_states.unsqueeze(0) - if indices_latents_history_short is not None and indices_latents_history_short.ndim == 1: - indices_latents_history_short = indices_latents_history_short.unsqueeze(0) - if indices_latents_history_mid is not None and indices_latents_history_mid.ndim == 1: - indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) - if indices_latents_history_long is not None and indices_latents_history_long.ndim == 1: - indices_latents_history_long = indices_latents_history_long.unsqueeze(0) batch_size = hidden_states.shape[0] p_t, p_h, p_w = self.config.patch_size From 389c83021218e2bb3ef884bfff7615a00d49337c Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Fri, 27 Feb 2026 03:45:24 +0000 Subject: [PATCH 037/107] remove redundant & refactor norm_out --- .../models/transformers/transformer_helios.py | 73 ++++++++----------- .../pipelines/helios/pipeline_helios.py | 16 ++-- 2 files changed, 38 insertions(+), 51 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 6022f2272138..a928b5eabf9e 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -82,6 +82,21 @@ def _get_qkv_projections(attn: "HeliosAttention", hidden_states: torch.Tensor, e return query, key, value +class HeliosOutputNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False): + super().__init__() + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + self.norm = FP32LayerNorm(dim, eps, 
elementwise_affine=False) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, original_context_length: int): + temb = temb[:, -original_context_length:, :] + shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2) + shift, scale = shift.squeeze(2).to(hidden_states.device), scale.squeeze(2).to(hidden_states.device) + hidden_states = hidden_states[:, -original_context_length:, :] + hidden_states = (self.norm(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + return hidden_states + + class HeliosAttnProcessor: _attention_backend = None _parallel_config = None @@ -523,7 +538,7 @@ class HeliosTransformer3DModel( _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] - _no_split_modules = ["HeliosTransformerBlock"] + _no_split_modules = ["HeliosTransformerBlock", "HeliosOutputNorm"] _keep_in_fp32_modules = [ "time_embedder", "scale_shift_table", @@ -614,9 +629,8 @@ def __init__( ) # 5. Output norm & projection - self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.norm_out = HeliosOutputNorm(inner_dim, eps, elementwise_affine=False) self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) - self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) self.gradient_checkpointing = False @@ -637,18 +651,11 @@ def forward( return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, ) -> torch.Tensor | dict[str, torch.Tensor]: - - + # 1. Input batch_size = hidden_states.shape[0] p_t, p_h, p_w = self.config.patch_size - # hidden: [high, mid, low] -> [low, mid, high] - post_patch_height_list = [] - post_patch_width_list = [] - post_patch_num_frames_list = [] - original_context_length_list = [] - - # Process noisy latents + # 2. 
Process noisy latents hidden_states = self.patch_embedding(hidden_states) B, C, T, H, W = hidden_states.shape @@ -664,12 +671,12 @@ def forward( ) rotary_emb = rotary_emb.flatten(2).transpose(1, 2) - post_patch_height_list.append(H) - post_patch_width_list.append(W) - post_patch_num_frames_list.append(T) - original_context_length_list.append(hidden_states.shape[1]) + post_patch_height = H + post_patch_width = W + post_patch_num_frames = T + original_context_length = hidden_states.shape[1] - # Process short history latents + # 3. Process short history latents if latents_history_short is not None and indices_latents_history_short is not None: latents_history_short = latents_history_short.to(hidden_states) latents_history_short = self.patch_short(latents_history_short) @@ -687,7 +694,7 @@ def forward( hidden_states = torch.cat([latents_history_short, hidden_states], dim=1) rotary_emb = torch.cat([rotary_emb_history_short, rotary_emb], dim=1) - # Process mid history latents + # 4. Process mid history latents if latents_history_mid is not None and indices_latents_history_mid is not None: latents_history_mid = latents_history_mid.to(hidden_states) latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4)) @@ -707,7 +714,7 @@ def forward( hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1) rotary_emb = torch.cat([rotary_emb_history_mid, rotary_emb], dim=1) - # Process long history latents + # 5. 
Process long history latents if latents_history_long is not None and indices_latents_history_long is not None: latents_history_long = latents_history_long.to(hidden_states) latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8)) @@ -727,10 +734,6 @@ def forward( hidden_states = torch.cat([latents_history_long, hidden_states], dim=1) rotary_emb = torch.cat([rotary_emb_history_long, rotary_emb], dim=1) - post_patch_num_frames = sum(post_patch_num_frames_list) - post_patch_height = sum(post_patch_height_list) - post_patch_width = sum(post_patch_width_list) - original_context_length = sum(original_context_length_list) history_context_length = hidden_states.shape[1] - original_context_length if indices_hidden_states is not None and self.zero_history_timestep: @@ -762,7 +765,7 @@ def forward( if timestep_proj.ndim == 4: timestep_proj = timestep_proj.permute(0, 2, 1, 3) - # 4. Transformer blocks + # 6. Transformer blocks hidden_states = hidden_states.contiguous() encoder_hidden_states = encoder_hidden_states.contiguous() rotary_emb = rotary_emb.contiguous() @@ -786,27 +789,11 @@ def forward( original_context_length, ) - # 5. Output norm, projection & unpatchify - if temb.ndim == 3: - temb = temb[:, -original_context_length:, :] - shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2) - shift = shift.squeeze(2) - scale = scale.squeeze(2) - else: - # batch_size, inner_dim - shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) - - # Move the shift and scale tensors to the same device as hidden_states. - # When using multi-GPU inference via accelerate these will be on the - # first device rather than the last device, which hidden_states ends up - # on. 
- shift = shift.to(hidden_states.device) - scale = scale.to(hidden_states.device) - - hidden_states = hidden_states[:, -original_context_length:, :] - hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + # 7. Normalization + hidden_states = self.norm_out(hidden_states, temb, original_context_length) hidden_states = self.proj_out(hidden_states) + # 8. Unpatchify hidden_states = hidden_states.reshape( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 ) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 5c2b1ffb09d7..5951784d7755 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -1336,10 +1336,10 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, guidance_scale=guidance_scale, - indices_hidden_states=indices_hidden_states, - indices_latents_history_short=indices_latents_history_short, - indices_latents_history_mid=indices_latents_history_mid, - indices_latents_history_long=indices_latents_history_long, + indices_hidden_states=indices_hidden_states.unsqueeze(0), + indices_latents_history_short=indices_latents_history_short.unsqueeze(0), + indices_latents_history_mid=indices_latents_history_mid.unsqueeze(0), + indices_latents_history_long=indices_latents_history_long.unsqueeze(0), latents_history_short=latents_history_short, latents_history_mid=latents_history_mid, latents_history_long=latents_history_long, @@ -1366,10 +1366,10 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, timesteps=timesteps, guidance_scale=guidance_scale, - indices_hidden_states=indices_hidden_states, - indices_latents_history_short=indices_latents_history_short, - indices_latents_history_mid=indices_latents_history_mid, - indices_latents_history_long=indices_latents_history_long, + 
indices_hidden_states=indices_hidden_states.unsqueeze(0), + indices_latents_history_short=indices_latents_history_short.unsqueeze(0), + indices_latents_history_mid=indices_latents_history_mid.unsqueeze(0), + indices_latents_history_long=indices_latents_history_long.unsqueeze(0), latents_history_short=latents_history_short, latents_history_mid=latents_history_mid, latents_history_long=latents_history_long, From cbe52d5cf8ec7722e3b931c12eb32ef662dd54b0 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Fri, 27 Feb 2026 11:47:38 +0800 Subject: [PATCH 038/107] Update src/diffusers/pipelines/helios/pipeline_helios.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/helios/pipeline_helios.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 5951784d7755..719cb428d8f7 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -909,7 +909,7 @@ def __call__( zero_steps: int | None = 1, # ------------ DMD ------------ use_dmd: bool = False, - is_skip_first_section: bool = False, + is_skip_first_chunk: bool = False, is_amplify_first_chunk: bool = False, ): r""" From 20eeed6b7712da3d24614ed516ca9ca36315e3cd Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Fri, 27 Feb 2026 03:49:05 +0000 Subject: [PATCH 039/107] change "is_keep_x0" to "keep_first_frame" --- 0_temp_helios_test/infer_helios.py | 10 +++++----- src/diffusers/pipelines/helios/pipeline_helios.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index a4a730a2dd53..2048c281c46c 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -279,7 +279,7 @@ def parse_list_input(input_string): # stage 1 history_sizes=[16, 2, 1], 
latent_window_size=args.latent_window_size, - is_keep_x0=True, + keep_first_frame=True, use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 is_enable_stage2=args.is_enable_stage2, @@ -340,7 +340,7 @@ def parse_list_input(input_string): # stage 1 history_sizes=[16, 2, 1], latent_window_size=args.latent_window_size, - is_keep_x0=True, + keep_first_frame=True, use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 is_enable_stage2=args.is_enable_stage2, @@ -409,7 +409,7 @@ def parse_list_input(input_string): # stage 1 history_sizes=[16, 2, 1], latent_window_size=args.latent_window_size, - is_keep_x0=True, + keep_first_frame=True, use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 is_enable_stage2=args.is_enable_stage2, @@ -467,7 +467,7 @@ def parse_list_input(input_string): # stage 1 history_sizes=[16, 2, 1], latent_window_size=args.latent_window_size, - is_keep_x0=True, + keep_first_frame=True, use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 is_enable_stage2=args.is_enable_stage2, @@ -510,7 +510,7 @@ def parse_list_input(input_string): # stage 1 history_sizes=[16, 2, 1], latent_window_size=args.latent_window_size, - is_keep_x0=True, + keep_first_frame=True, use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 is_enable_stage2=args.is_enable_stage2, diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 719cb428d8f7..910141729be2 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -898,7 +898,7 @@ def __call__( history_sizes: list = [16, 2, 1], latent_window_size: int = 9, use_dynamic_shifting: bool = False, - is_keep_x0: bool = True, + keep_first_frame: bool = True, # ------------ Stage 2 ------------ is_enable_stage2: bool = False, stage2_num_stages: int = 3, @@ -1138,7 +1138,7 @@ def __call__( history_video = None total_generated_latent_frames = 0 - if not is_keep_x0: + if not keep_first_frame: 
history_sizes[-1] = history_sizes[-1] + 1 history_latents = torch.zeros( batch_size, @@ -1201,7 +1201,7 @@ def __call__( is_first_section = k == 0 is_second_section = k == 1 - if is_keep_x0: + if keep_first_frame: if is_first_section: history_sizes_first_section = [1] + history_sizes.copy() history_latents_first_section = torch.zeros( @@ -1393,7 +1393,7 @@ def __call__( progress_bar=progress_bar, ) - if is_keep_x0 and ( + if keep_first_frame and ( (is_first_section and image_latents is None) or (is_skip_first_section and is_second_section) ): image_latents = latents[:, :, 0:1, :, :] From 38d50d217fac54c306340b3a45c9c832bd306f27 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Fri, 27 Feb 2026 04:03:48 +0000 Subject: [PATCH 040/107] use a more intuitive name --- 0_temp_helios_test/infer_helios.py | 24 +++--- .../pipelines/helios/pipeline_helios.py | 80 +++++++++---------- 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index 2048c281c46c..c6f83bf3c96e 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -71,7 +71,7 @@ def parse_args(): parser.add_argument("--use_zero_init", action="store_true") parser.add_argument("--zero_steps", type=int, default=1) # stage 1 - parser.add_argument("--latent_window_size", type=int, default=9) + parser.add_argument("--num_latent_frames_per_chunk", type=int, default=9) # stage 2 parser.add_argument("--is_enable_stage2", action="store_true") parser.add_argument("--stage2_num_stages", type=int, default=3) @@ -81,7 +81,7 @@ def parse_args(): parser.add_argument("--stage2_num_inference_steps_list", type=int, nargs="+", default=[20, 20, 20]) # stage 3 parser.add_argument("--is_enable_stage3", action="store_true") - parser.add_argument("--is_skip_first_section", action="store_true") + parser.add_argument("--is_skip_first_chunk", action="store_true") parser.add_argument("--is_amplify_first_chunk", action="store_true") 
# === Prompts === @@ -278,7 +278,7 @@ def parse_list_input(input_string): generator=torch.Generator(device="cuda").manual_seed(args.seed), # stage 1 history_sizes=[16, 2, 1], - latent_window_size=args.latent_window_size, + num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, keep_first_frame=True, use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 @@ -287,7 +287,7 @@ def parse_list_input(input_string): stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, # stage 3 use_dmd=args.is_enable_stage3, - is_skip_first_section=args.is_skip_first_section, + is_skip_first_chunk=args.is_skip_first_chunk, is_amplify_first_chunk=args.is_amplify_first_chunk, # cfg zero use_cfg_zero_star=args.use_cfg_zero_star, @@ -339,7 +339,7 @@ def parse_list_input(input_string): generator=torch.Generator(device="cuda").manual_seed(args.seed), # stage 1 history_sizes=[16, 2, 1], - latent_window_size=args.latent_window_size, + num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, keep_first_frame=True, use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 @@ -348,7 +348,7 @@ def parse_list_input(input_string): stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, # stage 3 use_dmd=args.is_enable_stage3, - is_skip_first_section=args.is_skip_first_section, + is_skip_first_chunk=args.is_skip_first_chunk, is_amplify_first_chunk=args.is_amplify_first_chunk, # cfg zero use_cfg_zero_star=args.use_cfg_zero_star, @@ -408,7 +408,7 @@ def parse_list_input(input_string): generator=torch.Generator(device="cuda").manual_seed(args.seed), # stage 1 history_sizes=[16, 2, 1], - latent_window_size=args.latent_window_size, + num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, keep_first_frame=True, use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 @@ -417,7 +417,7 @@ def parse_list_input(input_string): stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, # stage 3 use_dmd=args.is_enable_stage3, - 
is_skip_first_section=args.is_skip_first_section, + is_skip_first_chunk=args.is_skip_first_chunk, is_amplify_first_chunk=args.is_amplify_first_chunk, # cfg zero use_cfg_zero_star=args.use_cfg_zero_star, @@ -466,7 +466,7 @@ def parse_list_input(input_string): generator=torch.Generator(device="cuda").manual_seed(args.seed), # stage 1 history_sizes=[16, 2, 1], - latent_window_size=args.latent_window_size, + num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, keep_first_frame=True, use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 @@ -475,7 +475,7 @@ def parse_list_input(input_string): stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, # stage 3 use_dmd=args.is_enable_stage3, - is_skip_first_section=args.is_skip_first_section, + is_skip_first_chunk=args.is_skip_first_chunk, is_amplify_first_chunk=args.is_amplify_first_chunk, # cfg zero use_cfg_zero_star=args.use_cfg_zero_star, @@ -509,7 +509,7 @@ def parse_list_input(input_string): generator=torch.Generator(device="cuda").manual_seed(args.seed), # stage 1 history_sizes=[16, 2, 1], - latent_window_size=args.latent_window_size, + num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, keep_first_frame=True, use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 @@ -518,7 +518,7 @@ def parse_list_input(input_string): stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, # stage 3 use_dmd=args.is_enable_stage3, - is_skip_first_section=args.is_skip_first_section, + is_skip_first_chunk=args.is_skip_first_chunk, is_amplify_first_chunk=args.is_amplify_first_chunk, # cfg zero use_cfg_zero_star=args.use_cfg_zero_star, diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 910141729be2..462f120653b8 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -443,7 +443,7 @@ def prepare_video_latents( video: torch.Tensor, latents_mean: 
torch.Tensor, latents_std: torch.Tensor, - latent_window_size: int, + num_latent_frames_per_chunk: int, dtype: torch.dtype | None = None, device: torch.device | None = None, generator: torch.Generator | list[torch.Generator] | None = None, @@ -453,13 +453,13 @@ def prepare_video_latents( video = video.to(device=device, dtype=self.vae.dtype) if latents is None: num_frames = video.shape[2] - min_frames = (latent_window_size - 1) * 4 + 1 + min_frames = (num_latent_frames_per_chunk - 1) * 4 + 1 num_chunks = num_frames // min_frames if num_chunks == 0: raise ValueError( f"Video must have at least {min_frames} frames " f"(got {num_frames} frames). " - f"Required: (latent_window_size - 1) * 4 + 1 = ({latent_window_size} - 1) * 4 + 1 = {min_frames}" + f"Required: (num_latent_frames_per_chunk - 1) * 4 + 1 = ({num_latent_frames_per_chunk} - 1) * 4 + 1 = {min_frames}" ) total_valid_frames = num_chunks * min_frames start_frame = num_frames - total_valid_frames @@ -896,7 +896,7 @@ def __call__( interpolation_steps: int = 3, # ------------ Stage 1 ------------ history_sizes: list = [16, 2, 1], - latent_window_size: int = 9, + num_latent_frames_per_chunk: int = 9, use_dynamic_shifting: bool = False, keep_first_frame: bool = True, # ------------ Stage 2 ------------ @@ -1092,7 +1092,7 @@ def __call__( video, latents_mean=latents_mean, latents_std=latents_std, - latent_window_size=latent_window_size, + num_latent_frames_per_chunk=num_latent_frames_per_chunk, dtype=torch.float32, device=device, generator=generator, @@ -1110,10 +1110,10 @@ def __call__( ) noisy_latents_chunks = [] - num_latent_chunks = video_latents.shape[2] // latent_window_size + num_latent_chunks = video_latents.shape[2] // num_latent_frames_per_chunk for i in range(num_latent_chunks): - chunk_start = i * latent_window_size - chunk_end = chunk_start + latent_window_size + chunk_start = i * num_latent_frames_per_chunk + chunk_end = chunk_start + num_latent_frames_per_chunk latent_chunk = video_latents[:, :, 
chunk_start:chunk_end, :, :] chunk_frames = latent_chunk.shape[2] @@ -1133,8 +1133,8 @@ def __call__( # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels - window_num_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 - num_latent_sections = max(1, (num_frames + window_num_frames - 1) // window_num_frames) + window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + num_latent_chunk = max(1, (num_frames + window_num_frames - 1) // window_num_frames) history_video = None total_generated_latent_frames = 0 @@ -1164,13 +1164,13 @@ def __call__( # 6. Denoising loop if use_interpolate_prompt: - if num_latent_sections < max(interpolate_cumulative_list): - num_latent_sections = sum(interpolate_cumulative_list) - print(f"Update num_latent_sections to: {num_latent_sections}") + if num_latent_chunk < max(interpolate_cumulative_list): + num_latent_chunk = sum(interpolate_cumulative_list) + print(f"Update num_latent_chunk to: {num_latent_chunk}") - for k in range(num_latent_sections): + for k in range(num_latent_chunk): if use_interpolate_prompt: - assert num_latent_sections >= max(interpolate_cumulative_list) + assert num_latent_chunk >= max(interpolate_cumulative_list) current_interval_idx = 0 for idx, cumulative_val in enumerate(interpolate_cumulative_list): @@ -1199,62 +1199,62 @@ def __call__( else: prompt_embeds = all_prompt_embeds - is_first_section = k == 0 - is_second_section = k == 1 + is_first_chunk = k == 0 + is_second_chunk = k == 1 if keep_first_frame: - if is_first_section: - history_sizes_first_section = [1] + history_sizes.copy() - history_latents_first_section = torch.zeros( + if is_first_chunk: + history_sizes_first_chunk = [1] + history_sizes.copy() + history_latents_first_chunk = torch.zeros( batch_size, num_channels_latents, - sum(history_sizes_first_section), + sum(history_sizes_first_chunk), height // self.vae_scale_factor_spatial, width // 
self.vae_scale_factor_spatial, device=device, dtype=torch.float32, ) if fake_image_latents is not None: - history_latents_first_section = torch.cat( - [history_latents_first_section, fake_image_latents], dim=2 + history_latents_first_chunk = torch.cat( + [history_latents_first_chunk, fake_image_latents], dim=2 ) if video_latents is not None: - history_frames = history_latents_first_section.shape[2] + history_frames = history_latents_first_chunk.shape[2] video_frames = video_latents.shape[2] if video_frames < history_frames: keep_frames = history_frames - video_frames - history_latents_first_section = torch.cat( - [history_latents_first_section[:, :, :keep_frames, :, :], video_latents], dim=2 + history_latents_first_chunk = torch.cat( + [history_latents_first_chunk[:, :, :keep_frames, :, :], video_latents], dim=2 ) else: - history_latents_first_section = video_latents + history_latents_first_chunk = video_latents - indices = torch.arange(0, sum([1, *history_sizes, latent_window_size])) + indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) ( indices_prefix, indices_latents_history_long, indices_latents_history_mid, indices_latents_history_1x, indices_hidden_states, - ) = indices.split([1, *history_sizes, latent_window_size], dim=0) + ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) latents_prefix, latents_history_long, latents_history_mid, latents_history_1x = ( - history_latents_first_section[:, :, -sum(history_sizes_first_section) :].split( - history_sizes_first_section, dim=2 + history_latents_first_chunk[:, :, -sum(history_sizes_first_chunk) :].split( + history_sizes_first_chunk, dim=2 ) ) if image_latents is not None: latents_prefix = image_latents latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) else: - indices = torch.arange(0, sum([1, *history_sizes, latent_window_size])) + indices 
= torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) ( indices_prefix, indices_latents_history_long, indices_latents_history_mid, indices_latents_history_1x, indices_hidden_states, - ) = indices.split([1, *history_sizes, latent_window_size], dim=0) + ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) latents_prefix = image_latents @@ -1263,13 +1263,13 @@ def __call__( ].split(history_sizes, dim=2) latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) else: - indices = torch.arange(0, sum([*history_sizes, latent_window_size])) + indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk])) ( indices_latents_history_long, indices_latents_history_mid, indices_latents_history_short, indices_hidden_states, - ) = indices.split([*history_sizes, latent_window_size], dim=0) + ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0) latents_history_long, latents_history_mid, latents_history_short = history_latents[ :, :, -sum(history_sizes) : ].split(history_sizes, dim=2) @@ -1323,7 +1323,7 @@ def __call__( else: num_inference_steps = ( sum(stage2_num_inference_steps_list) * 2 - if is_amplify_first_chunk and use_dmd and is_first_section + if is_amplify_first_chunk and use_dmd and is_first_chunk else sum(stage2_num_inference_steps_list) ) @@ -1353,7 +1353,7 @@ def __call__( zero_steps=zero_steps, # -------------- DMD -------------- use_dmd=use_dmd, - is_amplify_first_chunk=is_amplify_first_chunk and is_first_section, + is_amplify_first_chunk=is_amplify_first_chunk and is_first_chunk, # ------------ Callback ------------ callback_on_step_end=callback_on_step_end, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, @@ -1386,7 +1386,7 @@ def __call__( use_dmd=use_dmd, dmd_sigmas=dmd_sigmas, dmd_timesteps=dmd_timesteps, - is_amplify_first_chunk=is_amplify_first_chunk 
and is_first_section, + is_amplify_first_chunk=is_amplify_first_chunk and is_first_chunk, # ------------ Callback ------------ callback_on_step_end=callback_on_step_end, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, @@ -1394,7 +1394,7 @@ def __call__( ) if keep_first_frame and ( - (is_first_section and image_latents is None) or (is_skip_first_section and is_second_section) + (is_first_chunk and image_latents is None) or (is_skip_first_chunk and is_second_chunk) ): image_latents = latents[:, :, 0:1, :, :] @@ -1404,7 +1404,7 @@ def __call__( index_slice = ( slice(None), slice(None), - slice(-latent_window_size, None), + slice(-num_latent_frames_per_chunk, None), ) current_latents = real_history_latents[index_slice].to(vae_dtype) / latents_std + latents_mean From 378d8a94607aa840b2a32206788ccb031410c552 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Fri, 27 Feb 2026 09:26:50 +0000 Subject: [PATCH 041/107] refactor dynamic_time_shifting --- .../pipelines/helios/pipeline_helios.py | 226 ++++++++---------- src/diffusers/schedulers/scheduling_helios.py | 40 ++++ 2 files changed, 143 insertions(+), 123 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 462f120653b8..f2c15d4919dd 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -13,10 +13,12 @@ # limitations under the License. 
import html +import inspect import math from itertools import accumulate from typing import Any, Callable +import numpy as np import regex as re import torch import torch.nn.functional as F @@ -117,45 +119,64 @@ def calculate_shift( return mu -def apply_schedule_shift( - image_seq_len, - sigmas, - sigmas_two=None, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.15, - exp_max: float = 7.0, - is_exponential: bool = False, - mu: float = None, - return_mu: bool = False, +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, ): - if mu is None: - # Resolution-dependent shifting of timestep schedules as per section 5.3.2 of SD3 paper - mu = calculate_shift( - image_seq_len, - base_seq_len, - max_seq_len, - base_shift, - max_shift, - ) - if is_exponential: - mu = min(mu, math.log(exp_max)) - mu = math.exp(mu) - - if sigmas_two is not None: - sigmas = (sigmas * mu) / (1 + (mu - 1) * sigmas) - sigmas_two = (sigmas_two * mu) / (1 + (mu - 1) * sigmas_two) - if return_mu: - return sigmas, sigmas_two, mu - else: - return sigmas, sigmas_two + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: - sigmas = (sigmas * mu) / (1 + (mu - 1) * sigmas) - if return_mu: - return sigmas, mu - else: - return sigmas + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps class HeliosPipeline(DiffusionPipeline, HeliosLoraLoaderMixin): @@ -523,15 +544,11 @@ def stage1_sample( transformer_dtype: torch.dtype = None, use_dynamic_shifting: bool = False, generator: torch.Generator | None = None, + num_warmup_steps: int | None = None, # ------------ CFG Zero ------------ use_cfg_zero_star: bool | None = False, use_zero_init: bool | None = True, zero_steps: int | None = 1, - # -------------- DMD -------------- - use_dmd: bool = False, - dmd_sigmas: torch.Tensor = None, - dmd_timesteps: torch.Tensor = None, - is_amplify_first_chunk: bool = False, # ------------ Callback ------------ callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], @@ -563,7 +580,7 @@ def stage1_sample( return_dict=False, )[0] - if self.do_classifier_free_guidance and not use_dmd: + if self.do_classifier_free_guidance: with self.transformer.cache_context("uncond"): noise_uncond = self.transformer( hidden_states=latent_model_input, @@ -596,26 +613,12 @@ def stage1_sample( else: noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - if isinstance(self.scheduler, HeliosScheduler): - latents = self.scheduler.step( - noise_pred, - t, - latents, - generator=generator, - return_dict=False, - cur_sampling_step=i, - dmd_noisy_tensor=randn_tensor(noise_pred.shape, generator=generator, device=device), - dmd_sigmas=dmd_sigmas, - dmd_timesteps=dmd_timesteps, - all_timesteps=timesteps, - )[0] - else: - latents = self.scheduler.step( - noise_pred, - t, - 
latents, - return_dict=False, - )[0] + latents = self.scheduler.step( + noise_pred, + t, + latents, + return_dict=False, + )[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -627,7 +630,8 @@ def stage1_sample( prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - progress_bar.update() + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() if XLA_AVAILABLE: xm.mark_step() @@ -688,8 +692,23 @@ def stage2_sample( i = 0 for i_s in range(stage2_num_stages): + patch_size = self.transformer.config.patch_size + image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( + patch_size[0] * patch_size[1] * patch_size[2] + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) self.scheduler.set_timesteps( - stage2_num_inference_steps_list[i_s], i_s, device=device, is_amplify_first_chunk=is_amplify_first_chunk + stage2_num_inference_steps_list[i_s], + i_s, + device=device, + mu=mu, + is_amplify_first_chunk=is_amplify_first_chunk, ) if i_s > 0: @@ -715,26 +734,6 @@ def stage2_sample( if use_dmd: start_point_list.append(latents) - if use_dynamic_shifting: - patch_size = self.transformer.config.patch_size - image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( - patch_size[0] * patch_size[1] * patch_size[2] - ) - temp_sigmas = apply_schedule_shift( - image_seq_len, - self.scheduler.sigmas, - base_seq_len=self.scheduler.config.get("base_image_seq_len", 256), - max_seq_len=self.scheduler.config.get("max_image_seq_len", 4096), - base_shift=self.scheduler.config.get("base_shift", 0.5), - max_shift=self.scheduler.config.get("max_shift", 1.15), - 
) - temp_timesteps = self.scheduler.timesteps_per_stage[i_s].min() + temp_sigmas[:-1] * ( - self.scheduler.timesteps_per_stage[i_s].max() - self.scheduler.timesteps_per_stage[i_s].min() - ) - - self.scheduler.sigmas = temp_sigmas - self.scheduler.timesteps = temp_timesteps - timesteps = self.scheduler.timesteps for idx, t in enumerate(timesteps): @@ -865,6 +864,7 @@ def __call__( width: int = 640, num_frames: int = 132, num_inference_steps: int = 50, + sigmas: list[float] = None, guidance_scale: float = 5.0, num_videos_per_prompt: int | None = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -1054,7 +1054,7 @@ def __call__( negative_prompt_embeds = negative_prompt_embeds[0].unsqueeze(0) negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - # 4. Prepare image + # 4. Prepare image or video if image is not None: image = self.video_processor.preprocess(image, height=height, width=width) image_latents, fake_image_latents = self.prepare_image_latents( @@ -1287,38 +1287,22 @@ def __call__( ) if not is_enable_stage2: - self.scheduler.set_timesteps(num_inference_steps, device=device) - - if use_dynamic_shifting: - patch_size = self.transformer.config.patch_size - image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( - patch_size[0] * patch_size[1] * patch_size[2] - ) - sigmas = torch.linspace( - 0.999, 0.0, steps=num_inference_steps + 1, dtype=torch.float32, device=device - )[:-1] - sigmas = apply_schedule_shift( - image_seq_len, - sigmas, - base_seq_len=self.scheduler.config.get("base_image_seq_len", 256), - max_seq_len=self.scheduler.config.get("max_image_seq_len", 4096), - base_shift=self.scheduler.config.get("base_shift", 0.5), - max_shift=self.scheduler.config.get("max_shift", 1.15), - ) - timesteps = sigmas * 1000.0 # rescale to [0, 1000.0) - timesteps = timesteps.to(device) - sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) - self.scheduler.timesteps = timesteps - 
self.scheduler.sigmas = sigmas - - timesteps = self.scheduler.timesteps - - dmd_sigmas = None - dmd_timesteps = None - if use_dmd: - dmd_sigmas = self.scheduler.sigmas.to(self.transformer.device) - dmd_timesteps = self.scheduler.timesteps.to(self.transformer.device) - + patch_size = self.transformer.config.patch_size + image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( + patch_size[0] * patch_size[1] * patch_size[2] + ) + sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu + ) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) else: num_inference_steps = ( @@ -1378,15 +1362,11 @@ def __call__( transformer_dtype=transformer_dtype, use_dynamic_shifting=use_dynamic_shifting, generator=generator, + num_warmup_steps=num_warmup_steps, # ------------ CFG Zero ------------ use_cfg_zero_star=use_cfg_zero_star, use_zero_init=use_zero_init, zero_steps=zero_steps, - # -------------- DMD -------------- - use_dmd=use_dmd, - dmd_sigmas=dmd_sigmas, - dmd_timesteps=dmd_timesteps, - is_amplify_first_chunk=is_amplify_first_chunk and is_first_chunk, # ------------ Callback ------------ callback_on_step_end=callback_on_step_end, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, diff --git a/src/diffusers/schedulers/scheduling_helios.py b/src/diffusers/schedulers/scheduling_helios.py index e2bc80d00f99..c18d6cb14d97 100644 --- a/src/diffusers/schedulers/scheduling_helios.py +++ b/src/diffusers/schedulers/scheduling_helios.py @@ -14,6 +14,7 @@ import math from 
dataclasses import dataclass +from typing import Literal import numpy as np import torch @@ -55,6 +56,8 @@ def __init__( use_flow_sigmas: bool = True, version: str = "v1", scheduler_type: str = "unipc", # ["euler", "unipc", "dmd"] + use_dynamic_shifting: bool = False, + time_shift_type: Literal["exponential", "linear"] = "linear", ): self.version = version self.timestep_ratios = {} # The timestep ratio for each stage @@ -206,6 +209,7 @@ def set_timesteps( num_inference_steps: int, stage_index: int, device: str | torch.device = None, + mu: bool | None = None, is_amplify_first_chunk: bool = False, ): """ @@ -261,6 +265,42 @@ def set_timesteps( self.timesteps = self.timesteps[:-1] self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]]) + if self.config.use_dynamic_shifting: + self.sigmas = self.time_shift(mu, 1.0, self.sigmas) + self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * ( + self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min() + ) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + """ + Apply time shifting to the sigmas. + + Args: + mu (`float`): + The mu parameter for the time shift. + sigma (`float`): + The sigma parameter for the time shift. + t (`torch.Tensor`): + The input timesteps. + + Returns: + `torch.Tensor`: + The time-shifted timesteps. 
+ """ + if self.config.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.config.time_shift_type == "linear": + return self._time_shift_linear(mu, sigma, t) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential + def _time_shift_exponential(self, mu, sigma, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear + def _time_shift_linear(self, mu, sigma, t): + return mu / (mu + (1 / t - 1) ** sigma) + # ---------------------------------- Euler ---------------------------------- def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: From eff865c499bbacc12302037be716cf244bf6dc9c Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Fri, 27 Feb 2026 11:35:03 +0000 Subject: [PATCH 042/107] remove use_dynamic_shifting args --- 0_temp_helios_test/infer_helios.py | 6 ------ 0_temp_helios_test/stage-1_i2v.sh | 1 - 0_temp_helios_test/stage-1_t2v.sh | 1 - 0_temp_helios_test/stage-1_v2v.sh | 1 - 0_temp_helios_test/stage-2_i2v.sh | 1 - 0_temp_helios_test/stage-2_t2v.sh | 1 - 0_temp_helios_test/stage-2_v2v.sh | 1 - 0_temp_helios_test/stage-3_i2v.sh | 1 - 0_temp_helios_test/stage-3_t2v.sh | 1 - 0_temp_helios_test/stage-3_v2v.sh | 1 - src/diffusers/pipelines/helios/pipeline_helios.py | 5 ----- 11 files changed, 20 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index c6f83bf3c96e..c194dbfd3da2 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -65,7 +65,6 @@ def parse_args(): parser.add_argument("--num_frames", type=int, default=99) parser.add_argument("--num_inference_steps", type=int, default=50) parser.add_argument("--guidance_scale", type=float, default=5.0) - 
parser.add_argument("--use_dynamic_shifting", action="store_true") # cfg zero parser.add_argument("--use_cfg_zero_star", action="store_true") parser.add_argument("--use_zero_init", action="store_true") @@ -280,7 +279,6 @@ def parse_list_input(input_string): history_sizes=[16, 2, 1], num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, keep_first_frame=True, - use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 is_enable_stage2=args.is_enable_stage2, stage2_num_stages=args.stage2_num_stages, @@ -341,7 +339,6 @@ def parse_list_input(input_string): history_sizes=[16, 2, 1], num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, keep_first_frame=True, - use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 is_enable_stage2=args.is_enable_stage2, stage2_num_stages=args.stage2_num_stages, @@ -410,7 +407,6 @@ def parse_list_input(input_string): history_sizes=[16, 2, 1], num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, keep_first_frame=True, - use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 is_enable_stage2=args.is_enable_stage2, stage2_num_stages=args.stage2_num_stages, @@ -468,7 +464,6 @@ def parse_list_input(input_string): history_sizes=[16, 2, 1], num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, keep_first_frame=True, - use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 is_enable_stage2=args.is_enable_stage2, stage2_num_stages=args.stage2_num_stages, @@ -511,7 +506,6 @@ def parse_list_input(input_string): history_sizes=[16, 2, 1], num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, keep_first_frame=True, - use_dynamic_shifting=args.use_dynamic_shifting, # stage 2 is_enable_stage2=args.is_enable_stage2, stage2_num_stages=args.stage2_num_stages, diff --git a/0_temp_helios_test/stage-1_i2v.sh b/0_temp_helios_test/stage-1_i2v.sh index 335cc832ae4a..968fb82a41b0 100644 --- a/0_temp_helios_test/stage-1_i2v.sh +++ b/0_temp_helios_test/stage-1_i2v.sh @@ -4,7 +4,6 @@ CUDA_VISIBLE_DEVICES=0 python 
infer_helios.py \ --sample_type "i2v" \ --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" \ --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." \ - --use_dynamic_shifting \ --output_folder "./output_helios/stage-1" diff --git a/0_temp_helios_test/stage-1_t2v.sh b/0_temp_helios_test/stage-1_t2v.sh index 6520b8a0713a..43756af73580 100644 --- a/0_temp_helios_test/stage-1_t2v.sh +++ b/0_temp_helios_test/stage-1_t2v.sh @@ -3,7 +3,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --transformer_path "BestWishYsh/Helios-Base" \ --sample_type "t2v" \ --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." 
\ - --use_dynamic_shifting \ --output_folder "./output_helios/stage-1" diff --git a/0_temp_helios_test/stage-1_v2v.sh b/0_temp_helios_test/stage-1_v2v.sh index 56a00631a988..25f882853f37 100644 --- a/0_temp_helios_test/stage-1_v2v.sh +++ b/0_temp_helios_test/stage-1_v2v.sh @@ -4,7 +4,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --sample_type "v2v" \ --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." \ - --use_dynamic_shifting \ --output_folder "./output_helios/stage-1" diff --git a/0_temp_helios_test/stage-2_i2v.sh b/0_temp_helios_test/stage-2_i2v.sh index 79a6515b5a24..9e235e7a5dbd 100644 --- a/0_temp_helios_test/stage-2_i2v.sh +++ b/0_temp_helios_test/stage-2_i2v.sh @@ -6,7 +6,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. 
The dynamic motion and vivid colors evoke both awe and respect for nature’s might." \ --is_enable_stage2 \ --stage2_num_inference_steps_list 20 20 20 \ - --use_dynamic_shifting \ --use_cfg_zero_star \ --use_zero_init \ --zero_steps 1 \ diff --git a/0_temp_helios_test/stage-2_t2v.sh b/0_temp_helios_test/stage-2_t2v.sh index c6fd3176e05d..36590dc1fce1 100644 --- a/0_temp_helios_test/stage-2_t2v.sh +++ b/0_temp_helios_test/stage-2_t2v.sh @@ -5,7 +5,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." \ --is_enable_stage2 \ --stage2_num_inference_steps_list 20 20 20 \ - --use_dynamic_shifting \ --use_cfg_zero_star \ --use_zero_init \ --zero_steps 1 \ diff --git a/0_temp_helios_test/stage-2_v2v.sh b/0_temp_helios_test/stage-2_v2v.sh index 261fdeed09b6..8f4181af607c 100644 --- a/0_temp_helios_test/stage-2_v2v.sh +++ b/0_temp_helios_test/stage-2_v2v.sh @@ -6,7 +6,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. 
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." \ --is_enable_stage2 \ --stage2_num_inference_steps_list 20 20 20 \ - --use_dynamic_shifting \ --use_cfg_zero_star \ --use_zero_init \ --zero_steps 1 \ diff --git a/0_temp_helios_test/stage-3_i2v.sh b/0_temp_helios_test/stage-3_i2v.sh index de40fc798414..f1adf67d337c 100644 --- a/0_temp_helios_test/stage-3_i2v.sh +++ b/0_temp_helios_test/stage-3_i2v.sh @@ -9,7 +9,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --is_enable_stage2 \ --stage2_num_inference_steps_list 2 2 2 \ --is_enable_stage3 \ - --use_dynamic_shifting \ --is_amplify_first_chunk \ --output_folder "./output_helios/stage-3" diff --git a/0_temp_helios_test/stage-3_t2v.sh b/0_temp_helios_test/stage-3_t2v.sh index 29a74c371e0e..1f1efc04df41 100644 --- a/0_temp_helios_test/stage-3_t2v.sh +++ b/0_temp_helios_test/stage-3_t2v.sh @@ -8,7 +8,6 @@ CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ --is_enable_stage2 \ --stage2_num_inference_steps_list 2 2 2 \ --is_enable_stage3 \ - --use_dynamic_shifting \ --is_amplify_first_chunk \ --output_folder "./output_helios/stage-3" diff --git a/0_temp_helios_test/stage-3_v2v.sh b/0_temp_helios_test/stage-3_v2v.sh index f4b001908e17..11a1a5fb6fcf 100644 --- a/0_temp_helios_test/stage-3_v2v.sh +++ b/0_temp_helios_test/stage-3_v2v.sh @@ -9,7 +9,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --is_enable_stage2 \ --stage2_num_inference_steps_list 2 2 2 \ --is_enable_stage3 \ - --use_dynamic_shifting \ --is_amplify_first_chunk \ --output_folder "./output_helios/stage-3" diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index f2c15d4919dd..b4da4b822a03 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -542,7 +542,6 @@ def stage1_sample( attention_kwargs: dict | None = None, device: torch.device | None = None, 
transformer_dtype: torch.dtype = None, - use_dynamic_shifting: bool = False, generator: torch.Generator | None = None, num_warmup_steps: int | None = None, # ------------ CFG Zero ------------ @@ -656,7 +655,6 @@ def stage2_sample( attention_kwargs: dict | None = None, device: torch.device | None = None, transformer_dtype: torch.dtype = None, - use_dynamic_shifting: bool = False, generator: torch.Generator | None = None, # ------------ CFG Zero ------------ use_cfg_zero_star: bool | None = False, @@ -897,7 +895,6 @@ def __call__( # ------------ Stage 1 ------------ history_sizes: list = [16, 2, 1], num_latent_frames_per_chunk: int = 9, - use_dynamic_shifting: bool = False, keep_first_frame: bool = True, # ------------ Stage 2 ------------ is_enable_stage2: bool = False, @@ -1330,7 +1327,6 @@ def __call__( attention_kwargs=attention_kwargs, device=device, transformer_dtype=transformer_dtype, - use_dynamic_shifting=use_dynamic_shifting, # ------------ CFG Zero ------------ use_cfg_zero_star=use_cfg_zero_star, use_zero_init=use_zero_init, @@ -1360,7 +1356,6 @@ def __call__( attention_kwargs=attention_kwargs, device=device, transformer_dtype=transformer_dtype, - use_dynamic_shifting=use_dynamic_shifting, generator=generator, num_warmup_steps=num_warmup_steps, # ------------ CFG Zero ------------ From 866c55ce60dbdcc13520002622f27b961dee1f4f Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Fri, 27 Feb 2026 13:31:26 +0000 Subject: [PATCH 043/107] remove usage of UniPCMultistepScheduler --- 0_temp_helios_test/infer_helios.py | 1 - .../pipelines/helios/pipeline_helios.py | 74 ++----------------- src/diffusers/schedulers/scheduling_helios.py | 66 ++++++----------- 3 files changed, 30 insertions(+), 111 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index c194dbfd3da2..0dba9486646c 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -14,7 +14,6 @@ from diffusers import 
HeliosTransformer3DModel from diffusers import HeliosPipeline -from diffusers.schedulers.scheduling_helios import HeliosScheduler from diffusers import ContextParallelConfig from diffusers.models import AutoencoderKLWan diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index b4da4b822a03..e157ce97d2b0 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -28,7 +28,7 @@ from ...image_processor import PipelineImageInput from ...loaders import HeliosLoraLoaderMixin from ...models import AutoencoderKLWan, HeliosTransformer3DModel -from ...schedulers import HeliosScheduler, UniPCMultistepScheduler +from ...schedulers import HeliosScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor @@ -119,66 +119,6 @@ def calculate_shift( return mu -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: int | None = None, - device: str | torch.device | None = None, - timesteps: list[int] | None = None, - sigmas: list[float] | None = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
- timesteps (`list[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`list[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." 
- ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - class HeliosPipeline(DiffusionPipeline, HeliosLoraLoaderMixin): r""" Pipeline for text-to-video / image-to-video / video-to-video generation using Helios. @@ -195,7 +135,7 @@ class HeliosPipeline(DiffusionPipeline, HeliosLoraLoaderMixin): the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. transformer ([`HeliosTransformer3DModel`]): Conditional Transformer to denoise the input latents. - scheduler ([`UniPCMultistepScheduler`]): + scheduler ([`HeliosScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
@@ -210,7 +150,7 @@ def __init__( tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, vae: AutoencoderKLWan, - scheduler: UniPCMultistepScheduler | HeliosScheduler, + scheduler: HeliosScheduler, transformer: HeliosTransformer3DModel, ): super().__init__() @@ -708,6 +648,7 @@ def stage2_sample( mu=mu, is_amplify_first_chunk=is_amplify_first_chunk, ) + timesteps = self.scheduler.timesteps if i_s > 0: height *= 2 @@ -732,8 +673,6 @@ def stage2_sample( if use_dmd: start_point_list.append(latents) - timesteps = self.scheduler.timesteps - for idx, t in enumerate(timesteps): timestep = t.expand(latents.shape[0]).to(torch.int64) @@ -1296,9 +1235,8 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu - ) + self.scheduler.set_timesteps(num_inference_steps, device=device, sigmas=sigmas, mu=mu) + timesteps = self.scheduler.timesteps num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) else: diff --git a/src/diffusers/schedulers/scheduling_helios.py b/src/diffusers/schedulers/scheduling_helios.py index c18d6cb14d97..691d13943117 100644 --- a/src/diffusers/schedulers/scheduling_helios.py +++ b/src/diffusers/schedulers/scheduling_helios.py @@ -54,12 +54,10 @@ def __init__( disable_corrector: list[int] = [], solver_p: SchedulerMixin = None, use_flow_sigmas: bool = True, - version: str = "v1", scheduler_type: str = "unipc", # ["euler", "unipc", "dmd"] use_dynamic_shifting: bool = False, time_shift_type: Literal["exponential", "linear"] = "linear", ): - self.version = version self.timestep_ratios = {} # The timestep ratio for each stage self.timesteps_per_stage = {} # The detailed timesteps per stage (fix max and min per stage) self.sigmas_per_stage = {} # always uniform [1000, 0] @@ -140,16 +138,6 @@ def 
init_sigmas_for_each_stage(self): self.start_sigmas[i_s] = start_sigma self.end_sigmas[i_s] = end_sigma - if self.version == "v2": - new_start_indice = ( - len(self.sigmas) - torch.searchsorted(self.sigmas.flip(0), start_sigma, right=True) - ).item() - self.sigmas_per_stage[i_s] = self.sigmas[new_start_indice:end_indice] - self.timesteps_per_stage[i_s] = self.timesteps[new_start_indice:end_indice] - - if self.version == "v2": - return - # Determine the ratio of each stage according to flow length tot_distance = sum(stage_distance) for i_s in range(stages): @@ -207,8 +195,9 @@ def _sigma_to_t(self, sigma): def set_timesteps( self, num_inference_steps: int, - stage_index: int, + stage_index: int| None = None, device: str | torch.device = None, + sigmas: bool | None = None, mu: bool | None = None, is_amplify_first_chunk: bool = False, ): @@ -224,39 +213,28 @@ def set_timesteps( self.num_inference_steps = num_inference_steps self.init_sigmas() - if self.version == "v1": + if self.config.stages == 1: + if sigmas is None: + sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype(np.float32) + if self.config.shift != 1.0: + assert not self.config.use_dynamic_shifting + sigmas = self.time_shift(shift, 1.0, sigmas) + timesteps = (sigmas * self.config.num_train_timesteps).copy() + sigmas = torch.from_numpy(sigmas) + else: stage_timesteps = self.timesteps_per_stage[stage_index] - timestep_max = stage_timesteps[0].item() - timestep_min = stage_timesteps[-1].item() - timesteps = np.linspace( - timestep_max, - timestep_min, + stage_timesteps[0].item(), + stage_timesteps[-1].item(), num_inference_steps, ) - self.timesteps = torch.from_numpy(timesteps).to(device=device) stage_sigmas = self.sigmas_per_stage[stage_index] - sigma_max = stage_sigmas[0].item() - sigma_min = stage_sigmas[-1].item() - - ratios = np.linspace(sigma_max, sigma_min, num_inference_steps) - sigmas = torch.from_numpy(ratios).to(device=device) - self.sigmas = 
torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) - else: - total_steps = len(self.timesteps_per_stage[stage_index]) - indices = np.linspace(0, total_steps - 1, num_inference_steps, dtype=int) + ratios = np.linspace(stage_sigmas[0].item(), stage_sigmas[-1].item(), num_inference_steps) + sigmas = torch.from_numpy(ratios) - self.timesteps = self.timesteps_per_stage[stage_index][indices].to(device=device) - - if stage_index == (self.config.stages - 1): - sigmas = self.sigmas_per_stage[stage_index][indices].to(device=device) - self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) - else: - sigmas = self.sigmas_per_stage[stage_index][indices].to(device=device) - self.sigmas = torch.cat( - [sigmas, torch.tensor([self.ori_start_sigmas[stage_index + 1]], device=sigmas.device)] - ) + self.timesteps = torch.from_numpy(timesteps).to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device=device) self._step_index = None self.reset_scheduler_history() @@ -266,10 +244,14 @@ def set_timesteps( self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]]) if self.config.use_dynamic_shifting: + assert self.config.shift == 1.0 self.sigmas = self.time_shift(mu, 1.0, self.sigmas) - self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * ( - self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min() - ) + if self.config.stages == 1: + self.timesteps = self.sigmas[:-1] * self.config.num_train_timesteps + else: + self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * ( + self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min() + ) # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift def time_shift(self, mu: float, sigma: float, t: torch.Tensor): From 54abd1c797abd2c25938b04524e23f71b637a025 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Fri, 27 Feb 2026 14:52:08 
+0000 Subject: [PATCH 044/107] separate stage2 sample to HeliosPyramidPipeline --- 0_temp_helios_test/infer_helios.py | 318 +---- docs/source/en/api/pipelines/helios.md | 15 +- docs/source/en/using-diffusers/helios.md | 7 +- docs/source/zh/using-diffusers/helios.md | 7 +- src/diffusers/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 2 +- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/auto_pipeline.py | 3 +- src/diffusers/pipelines/helios/__init__.py | 2 + .../pipelines/helios/pipeline_helios.py | 481 ++----- .../helios/pipeline_helios_pyramid.py | 1134 +++++++++++++++++ src/diffusers/schedulers/scheduling_helios.py | 8 +- .../dummy_torch_and_transformers_objects.py | 15 + 13 files changed, 1314 insertions(+), 684 deletions(-) create mode 100644 src/diffusers/pipelines/helios/pipeline_helios_pyramid.py diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index 0dba9486646c..37383f6c55ec 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -13,7 +13,7 @@ from tqdm import tqdm from diffusers import HeliosTransformer3DModel -from diffusers import HeliosPipeline +from diffusers import HeliosPipeline, HeliosPyramidPipeline from diffusers import ContextParallelConfig from diffusers.models import AutoencoderKLWan @@ -205,12 +205,20 @@ def main(): subfolder="vae", torch_dtype=torch.float32, ) - pipe = HeliosPipeline.from_pretrained( - args.base_model_path, - transformer=transformer, - vae=vae, - torch_dtype=args.weight_dtype, - ) + if not args.is_enable_stage2: + pipe = HeliosPipeline.from_pretrained( + args.base_model_path, + transformer=transformer, + vae=vae, + torch_dtype=args.weight_dtype, + ) + else: + pipe = HeliosPyramidPipeline.from_pretrained( + args.base_model_path, + transformer=transformer, + vae=vae, + torch_dtype=args.weight_dtype, + ) if args.lora_path is not None: pipe.load_lora_weights(args.lora_path, adapter_name="default") @@ -239,259 +247,32 @@ def 
main(): # transformer.set_attention_backend("flash") pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=world_size)) - if args.debug_mode: - - def parse_list_input(input_string): - input_string = input_string.strip("[]").strip() - if "," in input_string: - return [int(x.strip()) for x in input_string.split(",") if x.strip()] - else: - return [int(x.strip()) for x in input_string.split() if x.strip()] - - while True: - user_input = input("Please enter stage2_num_inference_steps_list (e.g., 10 20 30): ").strip() - - if user_input.lower() in ["q", "quit", "exit"]: - break - - try: - pyramid_steps = parse_list_input(user_input) - print(f"✅ Parsing successful: {pyramid_steps}") - except ValueError as e: - print(f"❌ Input format error: {e}") - print("Please re-enter...\n") - continue - - args.stage2_num_inference_steps_list = pyramid_steps - - with torch.no_grad(): - output = pipe( - prompt=prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, # 73 109 145 181 215 - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - generator=torch.Generator(device="cuda").manual_seed(args.seed), - # stage 1 - history_sizes=[16, 2, 1], - num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, - keep_first_frame=True, - # stage 2 - is_enable_stage2=args.is_enable_stage2, - stage2_num_stages=args.stage2_num_stages, - stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, - # stage 3 - use_dmd=args.is_enable_stage3, - is_skip_first_chunk=args.is_skip_first_chunk, - is_amplify_first_chunk=args.is_amplify_first_chunk, - # cfg zero - use_cfg_zero_star=args.use_cfg_zero_star, - use_zero_init=args.use_zero_init, - zero_steps=args.zero_steps, - # i2v - image=load_image(image_path).resize((args.width, args.height)) if image_path is not None else None, - # t2v - video=load_video(video_path) if video_path is not None else None, - # interpolate_prompt 
- use_interpolate_prompt=args.use_interpolate_prompt, - interpolation_steps=args.interpolation_steps, - interpolate_time_list=interpolate_time_list, - ).frames[0] - - if not args.enable_parallelism or rank == 0: - file_count = len( - [f for f in os.listdir(args.output_folder) if os.path.isfile(os.path.join(args.output_folder, f))] - ) - output_path = os.path.join( - args.output_folder, f"{file_count:04d}_{args.sample_type}_{int(time.time())}.mp4" - ) - export_to_video(output, output_path, fps=24) - elif args.prompt_txt_path is not None: - with open(args.prompt_txt_path, "r") as f: - prompt_list = [line.strip() for line in f.readlines() if line.strip()] - if not args.enable_parallelism: - prompt_list_with_idx = [(i, prompt) for i, prompt in enumerate(prompt_list)] - prompt_list_with_idx = prompt_list_with_idx[rank::world_size] - else: - prompt_list_with_idx = [(i, prompt) for i, prompt in enumerate(prompt_list)] - - for idx, prompt in tqdm(prompt_list_with_idx, desc="Processing prompts"): - output_path = os.path.join(args.output_folder, f"{idx}.mp4") - if os.path.exists(output_path): - print("skipping!") - continue - - with torch.no_grad(): - try: - output = pipe( - prompt=prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, # 73 109 145 181 215 - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - generator=torch.Generator(device="cuda").manual_seed(args.seed), - # stage 1 - history_sizes=[16, 2, 1], - num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, - keep_first_frame=True, - # stage 2 - is_enable_stage2=args.is_enable_stage2, - stage2_num_stages=args.stage2_num_stages, - stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, - # stage 3 - use_dmd=args.is_enable_stage3, - is_skip_first_chunk=args.is_skip_first_chunk, - is_amplify_first_chunk=args.is_amplify_first_chunk, - # cfg zero - use_cfg_zero_star=args.use_cfg_zero_star, - 
use_zero_init=args.use_zero_init, - zero_steps=args.zero_steps, - # i2v - image=load_image(image_path).resize((args.width, args.height)) - if image_path is not None - else None, - # t2v - video=load_video(video_path) if video_path is not None else None, - # interpolate_prompt - use_interpolate_prompt=args.use_interpolate_prompt, - interpolation_steps=args.interpolation_steps, - interpolate_time_list=interpolate_time_list, - ).frames[0] - except Exception: - continue - if not args.enable_parallelism or rank == 0: - export_to_video(output, output_path, fps=24) - elif args.interactive_prompt_csv_path is not None: - df = pd.read_csv(args.interactive_prompt_csv_path) - - df = df.sort_values(by=["id", "prompt_index"]) - all_video_ids = df["id"].unique() - - if not args.enable_parallelism: - my_video_ids = all_video_ids[rank::world_size] + with torch.no_grad(): + if not args.is_enable_stage2: + output = pipe( + prompt=prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, # 73 109 145 181 215 + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + generator=torch.Generator(device="cuda").manual_seed(args.seed), + # stage 1 + history_sizes=[16, 2, 1], + num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, + keep_first_frame=True, + is_skip_first_chunk=args.is_skip_first_chunk, + # i2v + image=load_image(image_path).resize((args.width, args.height)) if image_path is not None else None, + # t2v + video=load_video(video_path) if video_path is not None else None, + # interpolate_prompt + use_interpolate_prompt=args.use_interpolate_prompt, + interpolation_steps=args.interpolation_steps, + interpolate_time_list=interpolate_time_list, + ).frames[0] else: - my_video_ids = all_video_ids - - for video_id in tqdm(my_video_ids, desc="Processing prompts"): - output_path = os.path.join(args.output_folder, f"{video_id}.mp4") - - if os.path.exists(output_path): - print(f"skipping 
{output_path}!") - continue - - group_df = df[df["id"] == video_id] - - if "refined_prompt" in df.columns: - prompt_list = group_df["refined_prompt"].fillna(group_df["prompt"]).tolist() - else: - prompt_list = group_df["prompt"].tolist() - interpolate_time_list = [7] * len(prompt_list) - - with torch.no_grad(): - try: - output = pipe( - prompt=prompt_list, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, # 73 109 145 181 215 - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - generator=torch.Generator(device="cuda").manual_seed(args.seed), - # stage 1 - history_sizes=[16, 2, 1], - num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, - keep_first_frame=True, - # stage 2 - is_enable_stage2=args.is_enable_stage2, - stage2_num_stages=args.stage2_num_stages, - stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, - # stage 3 - use_dmd=args.is_enable_stage3, - is_skip_first_chunk=args.is_skip_first_chunk, - is_amplify_first_chunk=args.is_amplify_first_chunk, - # cfg zero - use_cfg_zero_star=args.use_cfg_zero_star, - use_zero_init=args.use_zero_init, - zero_steps=args.zero_steps, - # i2v - image=load_image(image_path).resize((args.width, args.height)) - if image_path is not None - else None, - # t2v - video=load_video(video_path) if video_path is not None else None, - # interpolate_prompt - use_interpolate_prompt=args.use_interpolate_prompt, - interpolation_steps=args.interpolation_steps, - interpolate_time_list=interpolate_time_list, - ).frames[0] - except Exception: - continue - if not args.enable_parallelism or rank == 0: - export_to_video(output, output_path, fps=24) - elif args.image_prompt_csv_path is not None: - df = pd.read_csv(args.image_prompt_csv_path) - if not args.enable_parallelism: - df = df.iloc[rank::world_size] - - for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing prompts"): - # output_path = 
os.path.join(args.output_folder, f"{idx}.mp4") - output_path = os.path.join(args.output_folder, f"{row['id']}.mp4") - if os.path.exists(output_path): - print("skipping!") - continue - - prompt = row.get("refined_prompt") or row["prompt"] - image_path = os.path.join(args.base_image_prompt_path, row["image_name"]) - - with torch.no_grad(): - try: - output = pipe( - prompt=prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, # 73 109 145 181 215 - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - generator=torch.Generator(device="cuda").manual_seed(args.seed), - # stage 1 - history_sizes=[16, 2, 1], - num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, - keep_first_frame=True, - # stage 2 - is_enable_stage2=args.is_enable_stage2, - stage2_num_stages=args.stage2_num_stages, - stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, - # stage 3 - use_dmd=args.is_enable_stage3, - is_skip_first_chunk=args.is_skip_first_chunk, - is_amplify_first_chunk=args.is_amplify_first_chunk, - # cfg zero - use_cfg_zero_star=args.use_cfg_zero_star, - use_zero_init=args.use_zero_init, - zero_steps=args.zero_steps, - # i2v - image=load_image(image_path).resize((args.width, args.height)) - if image_path is not None - else None, - # t2v - video=load_video(video_path) if video_path is not None else None, - # interpolate_prompt - use_interpolate_prompt=args.use_interpolate_prompt, - interpolation_steps=args.interpolation_steps, - interpolate_time_list=interpolate_time_list, - ).frames[0] - except Exception: - continue - if not args.enable_parallelism or rank == 0: - export_to_video(output, output_path, fps=24) - else: - with torch.no_grad(): output = pipe( prompt=prompt, negative_prompt=args.negative_prompt, @@ -505,13 +286,12 @@ def parse_list_input(input_string): history_sizes=[16, 2, 1], num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, 
keep_first_frame=True, + is_skip_first_chunk=args.is_skip_first_chunk, # stage 2 - is_enable_stage2=args.is_enable_stage2, stage2_num_stages=args.stage2_num_stages, stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, # stage 3 use_dmd=args.is_enable_stage3, - is_skip_first_chunk=args.is_skip_first_chunk, is_amplify_first_chunk=args.is_amplify_first_chunk, # cfg zero use_cfg_zero_star=args.use_cfg_zero_star, @@ -527,14 +307,14 @@ def parse_list_input(input_string): interpolate_time_list=interpolate_time_list, ).frames[0] - if not args.enable_parallelism or rank == 0: - file_count = len( - [f for f in os.listdir(args.output_folder) if os.path.isfile(os.path.join(args.output_folder, f))] - ) - output_path = os.path.join( - args.output_folder, f"{file_count:04d}_{args.sample_type}_{int(time.time())}.mp4" - ) - export_to_video(output, output_path, fps=24) + if not args.enable_parallelism or rank == 0: + file_count = len( + [f for f in os.listdir(args.output_folder) if os.path.isfile(os.path.join(args.output_folder, f))] + ) + output_path = os.path.join( + args.output_folder, f"{file_count:04d}_{args.sample_type}_{int(time.time())}.mp4" + ) + export_to_video(output, output_path, fps=24) print(f"Max memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB") diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index 418aaced1fb0..50a65af08b72 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -248,12 +248,12 @@ The example below demonstrates how to use Helios-Mid to generate video based on ```python import torch -from diffusers import AutoModel, HeliosPipeline +from diffusers import AutoModel, HeliosPyramidPipeline from diffusers.utils import export_to_video, load_video, load_image vae = AutoModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="vae", torch_dtype=torch.float32) -pipeline = HeliosPipeline.from_pretrained( +pipeline = 
HeliosPyramidPipeline.from_pretrained( "BestWishYsh/Helios-Mid", vae=vae, torch_dtype=torch.bfloat16 @@ -354,12 +354,12 @@ The example below demonstrates how to use Helios-Distilled to generate video bas ```python import torch -from diffusers import AutoModel, HeliosPipeline +from diffusers import AutoModel, HeliosPyramidPipeline from diffusers.utils import export_to_video, load_video, load_image vae = AutoModel.from_pretrained("BestWishYsh//Helios-Distilled", subfolder="vae", torch_dtype=torch.float32) -pipeline = HeliosPipeline.from_pretrained( +pipeline = HeliosPyramidPipeline.from_pretrained( "BestWishYsh//Helios-Distilled", vae=vae, torch_dtype=torch.bfloat16 @@ -458,6 +458,13 @@ export_to_video(output, "output_v2v.mp4", fps=24) - all - __call__ +## HeliosPyramidPipeline + +[[autodoc]] HeliosPyramidPipeline + + - all + - __call__ + ## HeliosPipelineOutput [[autodoc]] pipelines.Helios.pipeline_output.HeliosPipelineOutput diff --git a/docs/source/en/using-diffusers/helios.md b/docs/source/en/using-diffusers/helios.md index 0f9ea997e248..d95344003ef8 100644 --- a/docs/source/en/using-diffusers/helios.md +++ b/docs/source/en/using-diffusers/helios.md @@ -24,9 +24,8 @@ This guide will walk you through using Helios for use cases. Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method. 
```python -# !pip install Helios_eva_clip insightface facexlib import torch -from diffusers import HeliosPipeline +from diffusers import HeliosPipeline, HeliosPyramidPipeline from huggingface_hub import snapshot_download # For Best Quality @@ -36,12 +35,12 @@ pipe.to("cuda") # Intermediate Weight snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid") -pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16) +pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16) pipe.to("cuda") # For Best Efficiency snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled") -pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16) +pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16) pipe.to("cuda") ``` diff --git a/docs/source/zh/using-diffusers/helios.md b/docs/source/zh/using-diffusers/helios.md index 5e18f02bfe85..20a43b40da14 100644 --- a/docs/source/zh/using-diffusers/helios.md +++ b/docs/source/zh/using-diffusers/helios.md @@ -24,9 +24,8 @@ specific language governing permissions and limitations under the License. 
模型权重可以存储在Hub上或本地的单独子文件夹中,在这种情况下,您应该使用 [`~DiffusionPipeline.from_pretrained`] 方法。 ```python -# !pip install Helios_eva_clip insightface facexlib import torch -from diffusers import HeliosPipeline +from diffusers import HeliosPipeline, HeliosPyramidPipeline from huggingface_hub import snapshot_download # For Best Quality @@ -36,12 +35,12 @@ pipe.to("cuda") # Intermediate Weight snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid") -pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16) +pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16) pipe.to("cuda") # For Best Efficiency snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled") -pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16) +pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16) pipe.to("cuda") ``` diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 012e3d27b97c..ea6e06471379 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -518,6 +518,7 @@ "FluxPriorReduxPipeline", "GlmImagePipeline", "HeliosPipeline", + "HeliosPyramidPipeline", "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -1263,6 +1264,7 @@ FluxPriorReduxPipeline, GlmImagePipeline, HeliosPipeline, + HeliosPyramidPipeline, HiDreamImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 4f94b094c8f5..83865ba7c701 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -4323,7 +4323,7 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): class HeliosLoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`HeliosTransformer3DModel`]. 
Specific to [`HeliosPipeline`]. + Load LoRA layers into [`HeliosTransformer3DModel`]. Specific to [`HeliosPipeline`] and [`HeliosPyramidPipeline`]. """ _lora_loadable_modules = ["transformer"] diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 0002493e97e6..08cb28a6237a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -237,7 +237,7 @@ "EasyAnimateInpaintPipeline", "EasyAnimateControlPipeline", ] - _import_structure["helios"] = ["HeliosPipeline"] + _import_structure["helios"] = ["HeliosPipeline", "HeliosPyramidPipeline"] _import_structure["hidream_image"] = ["HiDreamImagePipeline"] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] _import_structure["hunyuan_video"] = [ @@ -668,7 +668,7 @@ ) from .flux2 import Flux2KleinPipeline, Flux2Pipeline from .glm_image import GlmImagePipeline - from .helios import HeliosPipeline + from .helios import HeliosPipeline, HeliosPyramidPipeline from .hidream_image import HiDreamImagePipeline from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline from .hunyuan_video import ( diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 52af1b1c6b46..72151dc40a53 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -54,7 +54,7 @@ ) from .flux2 import Flux2KleinPipeline, Flux2Pipeline from .glm_image import GlmImagePipeline -from .helios import HeliosPipeline +from .helios import HeliosPipeline, HeliosPyramidPipeline from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -176,6 +176,7 @@ ("cogview4", CogView4Pipeline), ("glm_image", GlmImagePipeline), ("helios", HeliosPipeline), + ("helios-pyramid", HeliosPyramidPipeline), ("cogview4-control", CogView4ControlPipeline), ("qwenimage", QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), diff --git 
a/src/diffusers/pipelines/helios/__init__.py b/src/diffusers/pipelines/helios/__init__.py index a540d1462389..ae08f5997279 100644 --- a/src/diffusers/pipelines/helios/__init__.py +++ b/src/diffusers/pipelines/helios/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_helios"] = ["HeliosPipeline"] + _import_structure["pipeline_helios_pyramid"] = ["HeliosPyramidPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -32,6 +33,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_helios import HeliosPipeline + from .pipeline_helios_pyramid import HeliosPyramidPipeline else: import sys diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index e157ce97d2b0..43d7732cf10f 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -13,15 +13,12 @@ # limitations under the License. 
import html -import inspect -import math from itertools import accumulate from typing import Any, Callable import numpy as np import regex as re import torch -import torch.nn.functional as F from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -465,308 +462,6 @@ def sample_block_noise(self, batch_size, channel, num_frames, height, width): noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width) return noise - def stage1_sample( - self, - latents: torch.Tensor = None, - prompt_embeds: torch.Tensor = None, - negative_prompt_embeds: torch.Tensor = None, - timesteps: torch.Tensor = None, - guidance_scale: float | None = 5.0, - indices_hidden_states: torch.Tensor = None, - indices_latents_history_short: torch.Tensor = None, - indices_latents_history_mid: torch.Tensor = None, - indices_latents_history_long: torch.Tensor = None, - latents_history_short: torch.Tensor = None, - latents_history_mid: torch.Tensor = None, - latents_history_long: torch.Tensor = None, - attention_kwargs: dict | None = None, - device: torch.device | None = None, - transformer_dtype: torch.dtype = None, - generator: torch.Generator | None = None, - num_warmup_steps: int | None = None, - # ------------ CFG Zero ------------ - use_cfg_zero_star: bool | None = False, - use_zero_init: bool | None = True, - zero_steps: int | None = 1, - # ------------ Callback ------------ - callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, - callback_on_step_end_tensor_inputs: list[str] = ["latents"], - progress_bar=None, - ): - batch_size = latents.shape[0] - - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - timestep = t.expand(latents.shape[0]) - - latent_model_input = latents.to(transformer_dtype) - with self.transformer.cache_context("cond"): - noise_pred = self.transformer( - 
hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - indices_hidden_states=indices_hidden_states, - indices_latents_history_short=indices_latents_history_short, - indices_latents_history_mid=indices_latents_history_mid, - indices_latents_history_long=indices_latents_history_long, - latents_history_short=latents_history_short.to(transformer_dtype), - latents_history_mid=latents_history_mid.to(transformer_dtype), - latents_history_long=latents_history_long.to(transformer_dtype), - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if self.do_classifier_free_guidance: - with self.transformer.cache_context("uncond"): - noise_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - indices_hidden_states=indices_hidden_states, - indices_latents_history_short=indices_latents_history_short, - indices_latents_history_mid=indices_latents_history_mid, - indices_latents_history_long=indices_latents_history_long, - latents_history_short=latents_history_short.to(transformer_dtype), - latents_history_mid=latents_history_mid.to(transformer_dtype), - latents_history_long=latents_history_long.to(transformer_dtype), - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if use_cfg_zero_star: - noise_pred_text = noise_pred - positive_flat = noise_pred_text.view(batch_size, -1) - negative_flat = noise_uncond.view(batch_size, -1) - - alpha = optimized_scale(positive_flat, negative_flat) - alpha = alpha.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1))) - alpha = alpha.to(noise_pred_text.dtype) - - if (i <= zero_steps) and use_zero_init: - noise_pred = noise_pred_text * 0.0 - else: - noise_pred = noise_uncond * alpha + guidance_scale * (noise_pred_text - noise_uncond * alpha) - else: - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - - latents = self.scheduler.step( - noise_pred, - t, - latents, - 
return_dict=False, - )[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - return latents - - def stage2_sample( - self, - latents: torch.Tensor = None, - stage2_num_stages: int = None, - stage2_num_inference_steps_list: list[int] = None, - prompt_embeds: torch.Tensor = None, - negative_prompt_embeds: torch.Tensor = None, - guidance_scale: float | None = 5.0, - indices_hidden_states: torch.Tensor = None, - indices_latents_history_short: torch.Tensor = None, - indices_latents_history_mid: torch.Tensor = None, - indices_latents_history_long: torch.Tensor = None, - latents_history_short: torch.Tensor = None, - latents_history_mid: torch.Tensor = None, - latents_history_long: torch.Tensor = None, - attention_kwargs: dict | None = None, - device: torch.device | None = None, - transformer_dtype: torch.dtype = None, - generator: torch.Generator | None = None, - # ------------ CFG Zero ------------ - use_cfg_zero_star: bool | None = False, - use_zero_init: bool | None = True, - zero_steps: int | None = 1, - # -------------- DMD -------------- - use_dmd: bool = False, - is_amplify_first_chunk: bool = False, - # ------------ Callback ------------ - callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, - callback_on_step_end_tensor_inputs: list[str] = ["latents"], - progress_bar=None, - ): - batch_size, num_channel, num_frmaes, height, width = latents.shape - latents = 
latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frmaes, num_channel, height, width) - for _ in range(stage2_num_stages - 1): - height //= 2 - width //= 2 - latents = ( - F.interpolate( - latents, - size=(height, width), - mode="bilinear", - ) - * 2 - ) - latents = latents.reshape(batch_size, num_frmaes, num_channel, height, width).permute(0, 2, 1, 3, 4) - - batch_size = latents.shape[0] - start_point_list = None - if use_dmd: - start_point_list = [latents] - - i = 0 - for i_s in range(stage2_num_stages): - patch_size = self.transformer.config.patch_size - image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( - patch_size[0] * patch_size[1] * patch_size[2] - ) - mu = calculate_shift( - image_seq_len, - self.scheduler.config.get("base_image_seq_len", 256), - self.scheduler.config.get("max_image_seq_len", 4096), - self.scheduler.config.get("base_shift", 0.5), - self.scheduler.config.get("max_shift", 1.15), - ) - self.scheduler.set_timesteps( - stage2_num_inference_steps_list[i_s], - i_s, - device=device, - mu=mu, - is_amplify_first_chunk=is_amplify_first_chunk, - ) - timesteps = self.scheduler.timesteps - - if i_s > 0: - height *= 2 - width *= 2 - num_frames = latents.shape[2] - latents = latents.permute(0, 2, 1, 3, 4).reshape( - batch_size * num_frmaes, num_channel, height // 2, width // 2 - ) - latents = F.interpolate(latents, size=(height, width), mode="nearest") - latents = latents.reshape(batch_size, num_frmaes, num_channel, height, width).permute(0, 2, 1, 3, 4) - # Fix the stage - ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s] # the original coeff of signal - gamma = self.scheduler.config.gamma - alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma) - beta = alpha * (1 - ori_sigma) / math.sqrt(gamma) - - batch_size, channel, num_frames, height, width = latents.shape - noise = self.sample_block_noise(batch_size, channel, num_frames, height, width) - noise = noise.to(device=device, dtype=transformer_dtype) 
- latents = alpha * latents + beta * noise # To fix the block artifact - - if use_dmd: - start_point_list.append(latents) - - for idx, t in enumerate(timesteps): - timestep = t.expand(latents.shape[0]).to(torch.int64) - - with self.transformer.cache_context("cond"): - noise_pred = self.transformer( - hidden_states=latents.to(transformer_dtype), - timestep=timestep, - encoder_hidden_states=prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - indices_hidden_states=indices_hidden_states, - indices_latents_history_short=indices_latents_history_short, - indices_latents_history_mid=indices_latents_history_mid, - indices_latents_history_long=indices_latents_history_long, - latents_history_short=latents_history_short.to(transformer_dtype), - latents_history_mid=latents_history_mid.to(transformer_dtype), - latents_history_long=latents_history_long.to(transformer_dtype), - )[0] - - if self.do_classifier_free_guidance: - with self.transformer.cache_context("cond_uncond"): - noise_uncond = self.transformer( - hidden_states=latents.to(transformer_dtype), - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - indices_hidden_states=indices_hidden_states, - indices_latents_history_short=indices_latents_history_short, - indices_latents_history_mid=indices_latents_history_mid, - indices_latents_history_long=indices_latents_history_long, - latents_history_short=latents_history_short.to(transformer_dtype), - latents_history_mid=latents_history_mid.to(transformer_dtype), - latents_history_long=latents_history_long.to(transformer_dtype), - )[0] - - if use_cfg_zero_star: - noise_pred_text = noise_pred - positive_flat = noise_pred_text.view(batch_size, -1) - negative_flat = noise_uncond.view(batch_size, -1) - - alpha = optimized_scale(positive_flat, negative_flat) - alpha = alpha.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1))) - alpha = alpha.to(noise_pred_text.dtype) - - if (i_s == 
0 and idx <= zero_steps) and use_zero_init: - noise_pred = noise_pred_text * 0.0 - else: - noise_pred = noise_uncond * alpha + guidance_scale * ( - noise_pred_text - noise_uncond * alpha - ) - else: - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - - if isinstance(self.scheduler, HeliosScheduler): - latents = self.scheduler.step( - noise_pred, - t, - latents, - generator=generator, - return_dict=False, - cur_sampling_step=idx, - dmd_noisy_tensor=start_point_list[i_s] if start_point_list is not None else None, - dmd_sigmas=self.scheduler.sigmas, - dmd_timesteps=self.scheduler.timesteps, - all_timesteps=timesteps, - )[0] - else: - latents = self.scheduler.step( - noise_pred, - t, - latents, - return_dict=False, - )[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - i += 1 - - return latents - @property def guidance_scale(self): return self._guidance_scale @@ -835,18 +530,7 @@ def __call__( history_sizes: list = [16, 2, 1], num_latent_frames_per_chunk: int = 9, keep_first_frame: bool = True, - # ------------ Stage 2 ------------ - is_enable_stage2: bool = False, - stage2_num_stages: int = 3, - stage2_num_inference_steps_list: list = [10, 10, 10], - # ------------ CFG Zero ------------ - use_cfg_zero_star: bool | None = False, - use_zero_init: bool | None = True, - zero_steps: int | None = 1, - # ------------ DMD ------------ - use_dmd: bool = False, is_skip_first_chunk: bool = False, - is_amplify_first_chunk: bool = False, ): r""" The call function to the pipeline for generation. 
@@ -1210,6 +894,11 @@ def __call__( :, :, -sum(history_sizes) : ].split(history_sizes, dim=2) + indices_hidden_states = indices_hidden_states.unsqueeze(0) + indices_latents_history_short = indices_latents_history_short.unsqueeze(0) + indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) + indices_latents_history_long = indices_latents_history_long.unsqueeze(0) + latents = self.prepare_latents( batch_size, num_channels_latents, @@ -1222,89 +911,89 @@ def __call__( latents=None, ) - if not is_enable_stage2: - patch_size = self.transformer.config.patch_size - image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( - patch_size[0] * patch_size[1] * patch_size[2] - ) - sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas - mu = calculate_shift( - image_seq_len, - self.scheduler.config.get("base_image_seq_len", 256), - self.scheduler.config.get("max_image_seq_len", 4096), - self.scheduler.config.get("base_shift", 0.5), - self.scheduler.config.get("max_shift", 1.15), - ) - self.scheduler.set_timesteps(num_inference_steps, device=device, sigmas=sigmas, mu=mu) - timesteps = self.scheduler.timesteps - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) - else: - num_inference_steps = ( - sum(stage2_num_inference_steps_list) * 2 - if is_amplify_first_chunk and use_dmd and is_first_chunk - else sum(stage2_num_inference_steps_list) - ) + patch_size = self.transformer.config.patch_size + image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( + patch_size[0] * patch_size[1] * patch_size[2] + ) + sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + 
self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.set_timesteps(num_inference_steps, device=device, sigmas=sigmas, mu=mu) + timesteps = self.scheduler.timesteps + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: - if is_enable_stage2: - latents = self.stage2_sample( - latents=latents, - stage2_num_stages=stage2_num_stages, - stage2_num_inference_steps_list=stage2_num_inference_steps_list, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - guidance_scale=guidance_scale, - indices_hidden_states=indices_hidden_states.unsqueeze(0), - indices_latents_history_short=indices_latents_history_short.unsqueeze(0), - indices_latents_history_mid=indices_latents_history_mid.unsqueeze(0), - indices_latents_history_long=indices_latents_history_long.unsqueeze(0), - latents_history_short=latents_history_short, - latents_history_mid=latents_history_mid, - latents_history_long=latents_history_long, - attention_kwargs=attention_kwargs, - device=device, - transformer_dtype=transformer_dtype, - # ------------ CFG Zero ------------ - use_cfg_zero_star=use_cfg_zero_star, - use_zero_init=use_zero_init, - zero_steps=zero_steps, - # -------------- DMD -------------- - use_dmd=use_dmd, - is_amplify_first_chunk=is_amplify_first_chunk and is_first_chunk, - # ------------ Callback ------------ - callback_on_step_end=callback_on_step_end, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - progress_bar=progress_bar, - ) - else: - latents = self.stage1_sample( - latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - timesteps=timesteps, - guidance_scale=guidance_scale, - indices_hidden_states=indices_hidden_states.unsqueeze(0), - indices_latents_history_short=indices_latents_history_short.unsqueeze(0), - 
indices_latents_history_mid=indices_latents_history_mid.unsqueeze(0), - indices_latents_history_long=indices_latents_history_long.unsqueeze(0), - latents_history_short=latents_history_short, - latents_history_mid=latents_history_mid, - latents_history_long=latents_history_long, - attention_kwargs=attention_kwargs, - device=device, - transformer_dtype=transformer_dtype, + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]) + + latent_model_input = latents.to(transformer_dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short.to(transformer_dtype), + latents_history_mid=latents_history_mid.to(transformer_dtype), + latents_history_long=latents_history_long.to(transformer_dtype), + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short.to(transformer_dtype), + latents_history_mid=latents_history_mid.to(transformer_dtype), + latents_history_long=latents_history_long.to(transformer_dtype), + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + 
guidance_scale * (noise_pred - noise_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, generator=generator, - num_warmup_steps=num_warmup_steps, - # ------------ CFG Zero ------------ - use_cfg_zero_star=use_cfg_zero_star, - use_zero_init=use_zero_init, - zero_steps=zero_steps, - # ------------ Callback ------------ - callback_on_step_end=callback_on_step_end, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - progress_bar=progress_bar, - ) + return_dict=False, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() if keep_first_frame and ( (is_first_chunk and image_latents is None) or (is_skip_first_chunk and is_second_chunk) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py new file mode 100644 index 000000000000..2d69303fec91 --- /dev/null +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -0,0 +1,1134 @@ +# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import math +from itertools import accumulate +from typing import Any, Callable + +import regex as re +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import HeliosLoraLoaderMixin +from ...models import AutoencoderKLWan, HeliosTransformer3DModel +from ...schedulers import HeliosScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HeliosPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers.utils import export_to_video + >>> from diffusers import AutoencoderKLWan, HeliosPyramidPipeline + + >>> # Available models: BestWishYsh/Helios-Base, BestWishYsh/Helios-Mid, BestWishYsh/Helios-Distilled + >>> model_id = "BestWishYsh/Helios-Base" + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = HeliosPyramidPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> 
pipe.to("cuda")

    >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
    >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    >>> output = pipe(
    ...     prompt=prompt,
    ...     negative_prompt=negative_prompt,
    ...     height=384,
    ...     width=640,
    ...     num_frames=132,
    ...     guidance_scale=5.0,
    ... ).frames[0]
    >>> export_to_video(output, "output.mp4", fps=24)
    ```
"""


def optimized_scale(positive_flat, negative_flat):
    """Return the per-sample scale that projects the conditional prediction onto the unconditional one.

    Used by the CFG-Zero* guidance variant: ``st_star = v_cond^T v_uncond / ||v_uncond||^2``, computed
    independently for every batch element over tensors flattened to shape (batch, -1).
    """
    # Calculate dot product between conditional and unconditional predictions, per batch element
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    # Squared norm of the unconditional prediction; epsilon guards against division by zero
    squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    st_star = dot_product / squared_norm
    return st_star


def basic_clean(text):
    """Fix mojibake with ftfy and unescape HTML entities (twice, for double-encoded input), then strip."""
    text = ftfy.fix_text(text)
    # Unescape twice to also handle double-encoded entities such as "&amp;amp;"
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    """Collapse every whitespace run to a single space and strip leading/trailing whitespace."""
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def prompt_clean(text):
    """Normalize a user prompt: mojibake/HTML cleanup followed by whitespace collapsing."""
    text = whitespace_clean(basic_clean(text))
    return text


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


class
HeliosPyramidPipeline(DiffusionPipeline, HeliosLoraLoaderMixin):
    r"""
    Pipeline for text-to-video / image-to-video / video-to-video generation using Helios.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        tokenizer ([`T5Tokenizer`]):
            Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
            specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
        text_encoder ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
        transformer ([`HeliosTransformer3DModel`]):
            Conditional Transformer to denoise the input latents.
        scheduler ([`HeliosScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLWan`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
    """

    # Order in which sub-models are shuttled on/off the accelerator during model CPU offload.
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    # Tensor names that callbacks registered via `callback_on_step_end` may inspect/override.
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
    _optional_components = ["transformer"]

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: UMT5EncoderModel,
        vae: AutoencoderKLWan,
        scheduler: HeliosScheduler,
        transformer: HeliosTransformer3DModel,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Fall back to Wan-VAE defaults (4x temporal / 8x spatial compression) when no VAE is registered.
        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    def _get_t5_prompt_embeds(
        self,
        prompt: str | list[str] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Tokenize and encode `prompt` with the UMT5 encoder, zero-padding embeddings to `max_sequence_length`."""
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(u) for u in prompt]
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
        # Effective (non-padded) token count of every prompt in the batch.
        seq_lens = mask.gt(0).sum(dim=1).long()

        prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        # Drop embeddings at padded positions, then re-pad with zeros to a fixed sequence length.
        prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
        prompt_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
        )

        # duplicate
text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        # NOTE(review): the attention mask is returned un-repeated (batch_size rows, not
        # batch_size * num_videos_per_prompt) and stays on CPU — confirm downstream consumers expect this.
        return prompt_embeds, text_inputs.attention_mask.bool()

    def encode_prompt(
        self,
        prompt: str | list[str],
        negative_prompt: str | list[str] | None = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: torch.Tensor | None = None,
        negative_prompt_embeds: torch.Tensor | None = None,
        max_sequence_length: int = 226,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # NOTE(review): if `prompt_embeds` is supplied, `prompt_attention_mask` is never assigned and the
        # return below raises UnboundLocalError — confirm callers always pass `prompt` instead.
        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        negative_prompt_attention_mask = None
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            # Broadcast a single negative prompt string across the whole batch.
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

    def check_inputs(
        self,
        prompt,
        negative_prompt,
        height,
        width,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        use_interpolate_prompt=False,
        num_videos_per_prompt=None,
        interpolate_time_list=None,
        interpolation_steps=None,
    ):
        """Validate user-supplied pipeline arguments; raises ValueError/TypeError on inconsistencies."""
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif negative_prompt is not None and (
            not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
        ):
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

        if use_interpolate_prompt:
            # NOTE(review): `assert` is stripped under `python -O`; these are user-input checks and
            # should arguably raise ValueError instead.
            assert num_videos_per_prompt == 1, f"num_videos_per_prompt must be 1, got {num_videos_per_prompt}"
            assert isinstance(prompt, list), "prompt must be a list"
            assert len(prompt) == len(interpolate_time_list), (
                f"Length mismatch: {len(prompt)} vs {len(interpolate_time_list)}"
            )
            assert min(interpolate_time_list) > interpolation_steps, (
                f"Minimum value {min(interpolate_time_list)} must be greater than {interpolation_steps}"
            )

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 384,
        width: int = 640,
        num_frames: int = 33,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        generator: torch.Generator | list[torch.Generator] | None = None,
        latents: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Draw initial Gaussian latents of shape (B, C, T_latent, H/spatial, W/spatial), or reuse `latents`."""
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        # Temporal VAE compression: one latent frame covers `vae_scale_factor_temporal` pixel frames
        # (plus the stand-alone first frame).
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_channels_latents,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    def prepare_image_latents(
        self,
        image: torch.Tensor,
        latents_mean: torch.Tensor,
        latents_std: torch.Tensor,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        generator: torch.Generator | list[torch.Generator] | None = None,
        latents: torch.Tensor | None = None,
        fake_latents: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode a conditioning image into VAE latents plus "fake" latents from a 33-frame still clip.

        `latents_std` is the *reciprocal* std (the caller computes `1.0 / std`), hence the multiplication
        in the normalization below.
        """
        device = device or self._execution_device
        if latents is None:
            # (B, C, H, W) -> (B, C, 1, H, W) so the image encodes as a one-frame video.
            image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype)
            latents = self.vae.encode(image).latent_dist.sample(generator=generator)
            latents = (latents - latents_mean) * latents_std
        if fake_latents is None:
            # NOTE(review): this branch assumes `image` was unsqueezed above — if `latents` is passed in
            # while `fake_latents` is not, `image` is still 4-D here; confirm callers never do that.
            # Encode the image repeated as a 33-frame still clip and keep only the last latent frame.
            fake_video = image.repeat(1, 1, 33, 1, 1).to(device=device, dtype=self.vae.dtype)
            fake_latents_full = self.vae.encode(fake_video).latent_dist.sample(generator=generator)
            fake_latents_full = (fake_latents_full - latents_mean) * latents_std
            fake_latents = fake_latents_full[:, :, -1:, :, :]
        return latents.to(device=device, dtype=dtype), fake_latents.to(device=device, dtype=dtype)

    def prepare_video_latents(
        self,
        video: torch.Tensor,
        latents_mean: torch.Tensor,
        latents_std: torch.Tensor,
        num_latent_frames_per_chunk: int,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        generator: torch.Generator | list[torch.Generator] | None = None,
        latents: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode a conditioning video chunk-by-chunk into VAE latents plus a first-frame latent.

        NOTE(review): if `latents` is supplied, `first_frame_latent` is never assigned and the return
        below raises UnboundLocalError — confirm callers always pass `latents=None`.
        """
        device = device or self._execution_device
        video = video.to(device=device, dtype=self.vae.dtype)
        if latents is None:
            num_frames = video.shape[2]
            # Each latent chunk decodes to (num_latent_frames_per_chunk - 1) * 4 + 1 pixel frames.
            min_frames = (num_latent_frames_per_chunk - 1) * 4 + 1
            num_chunks = num_frames // min_frames
            if num_chunks == 0:
                raise ValueError(
                    f"Video must have at least {min_frames} frames "
                    f"(got {num_frames} frames). "
                    f"Required: (num_latent_frames_per_chunk - 1) * 4 + 1 = ({num_latent_frames_per_chunk} - 1) * 4 + 1 = {min_frames}"
                )
            total_valid_frames = num_chunks * min_frames
            # Keep the trailing frames; any leading remainder that does not fill a whole chunk is dropped.
            start_frame = num_frames - total_valid_frames

            first_frame = video[:, :, 0:1, :, :]
            first_frame_latent = self.vae.encode(first_frame).latent_dist.sample(generator=generator)
            first_frame_latent = (first_frame_latent - latents_mean) * latents_std

            # Encode chunks from last to first; insert(0, ...) restores chronological order.
            latents_chunks = []
            for i in range(num_chunks - 1, -1, -1):
                chunk_start = start_frame + i * min_frames
                chunk_end = chunk_start + min_frames
                video_chunk = video[:, :, chunk_start:chunk_end, :, :]
                chunk_latents = self.vae.encode(video_chunk).latent_dist.sample(generator=generator)
                chunk_latents = (chunk_latents - latents_mean) * latents_std
                latents_chunks.insert(0, chunk_latents)
            latents = torch.cat(latents_chunks, dim=2)
        return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype)

    def interpolate_prompt_embeds(
        self,
        prompt_embeds_1: torch.Tensor,
        prompt_embeds_2: torch.Tensor,
        interpolation_steps: int = 3,
    ):
        """Linearly interpolate between two prompt embeddings; returns `interpolation_steps` tensors."""
        # torch.lerp broadcasts the (steps, 1, 1) weight tensor over the (1, seq, dim) embeddings.
        x = torch.lerp(
            prompt_embeds_1,
            prompt_embeds_2,
            torch.linspace(0, 1, steps=interpolation_steps).unsqueeze(1).unsqueeze(2).to(prompt_embeds_1),
        )
        interpolated_prompt_embeds = list(x.chunk(interpolation_steps, dim=0))
        return interpolated_prompt_embeds

    def sample_block_noise(self, batch_size, channel, num_frames, height, width):
        """Sample Gaussian noise that is correlated within non-overlapping 2x2 spatial blocks.

        Each block's four values follow N(0, I*(1+gamma) - 1*gamma) with `gamma` taken from the
        scheduler config, i.e. they are negatively correlated with strength `gamma`.
        """
        gamma = self.scheduler.config.gamma
        cov = torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma
        dist = torch.distributions.MultivariateNormal(torch.zeros(4, device=cov.device), covariance_matrix=cov)
        block_number = batch_size * channel * num_frames * (height // 2) * (width // 2)

        noise = dist.sample((block_number,))  # [block number, 4]
        # Scatter each sampled 4-vector back into its 2x2 spatial block.
        noise = noise.view(batch_size, channel, num_frames, height // 2, width // 2, 2, 2)
        noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width)
        return noise

    @property
    def guidance_scale(self):
        # Set at the start of every `__call__` invocation.
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        # Classifier-free guidance is active only when the guidance scale exceeds 1.
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: str | list[str] = None,
        negative_prompt: str | list[str] = None,
        height: int = 384,
        width: int = 640,
        num_frames: int = 132,
        num_inference_steps: int = 50,
        sigmas: list[float] = None,
        guidance_scale: float = 5.0,
        num_videos_per_prompt: int | None = 1,
        generator: torch.Generator | list[torch.Generator] | None = None,
        latents: torch.Tensor | None = None,
        prompt_embeds: torch.Tensor | None = None,
        negative_prompt_embeds: torch.Tensor | None = None,
        output_type: str | None = "np",
        return_dict: bool = True,
        attention_kwargs: dict[str, Any] | None = None,
        callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
        # NOTE(review): mutable default argument — safe only if never mutated in place; verify.
        callback_on_step_end_tensor_inputs: list[str] = ["latents"],
        max_sequence_length: int = 512,
        # ------------ I2V ------------
        image: PipelineImageInput | None = None,
        image_latents: torch.Tensor | None = None,
        fake_image_latents: torch.Tensor | None = None,
        add_noise_to_image_latents: bool = True,
        image_noise_sigma_min: float = 0.111,
        image_noise_sigma_max: float = 0.135,
        # ------------ V2V ------------
        video: PipelineImageInput | None = None,
        video_latents: torch.Tensor | None = None,
        add_noise_to_video_latents: bool = True,
        video_noise_sigma_min: float = 0.111,
        video_noise_sigma_max: float = 0.135,
        # ------------ Interactive ------------
        use_interpolate_prompt: bool = False,
interpolate_time_list: list = [7, 7, 7], + interpolation_steps: int = 3, + # ------------ Stage 1 ------------ + history_sizes: list = [16, 2, 1], + num_latent_frames_per_chunk: int = 9, + keep_first_frame: bool = True, + is_skip_first_chunk: bool = False, + # ------------ Stage 2 ------------ + stage2_num_stages: int = 3, + stage2_num_inference_steps_list: list = [10, 10, 10], + # ------------ CFG Zero ------------ + use_cfg_zero_star: bool | None = False, + use_zero_init: bool | None = True, + zero_steps: int | None = 1, + # ------------ DMD ------------ + use_dmd: bool = False, + is_amplify_first_chunk: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `384`): + The height in pixels of the generated image. + width (`int`, defaults to `640`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `132`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. 
+ num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HeliosPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 
+ callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~HeliosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HeliosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if image is not None and video is not None: + raise ValueError("image and video cannot be provided simultaneously") + + history_sizes = sorted(history_sizes, reverse=True) # From big to small + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(self.vae.device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + self.vae.device, self.vae.dtype + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + use_interpolate_prompt, + num_videos_per_prompt, + interpolate_time_list, + interpolation_steps, + ) + + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + vae_dtype = self.vae.dtype + + # 2. Define call parameters + if use_interpolate_prompt or (prompt is not None and isinstance(prompt, str)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + if use_interpolate_prompt: + interpolate_interval_idx = None + interpolate_embeds = None + interpolate_cumulative_list = list(accumulate(interpolate_time_list)) + + all_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( + self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + ) + + transformer_dtype = self.transformer.dtype + all_prompt_embeds = all_prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + if use_interpolate_prompt: + negative_prompt_embeds = negative_prompt_embeds[0].unsqueeze(0) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. 
Prepare image or video + if image is not None: + image = self.video_processor.preprocess(image, height=height, width=width) + image_latents, fake_image_latents = self.prepare_image_latents( + image, + latents_mean=latents_mean, + latents_std=latents_std, + dtype=torch.float32, + device=device, + generator=generator, + latents=image_latents, + fake_latents=fake_image_latents, + ) + + if image_latents is not None and add_noise_to_image_latents: + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + fake_image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min) + + video_noise_sigma_min + ) + fake_image_latents = ( + fake_image_noise_sigma * randn_tensor(fake_image_latents.shape, generator=generator, device=device) + + (1 - fake_image_noise_sigma) * fake_image_latents + ) + + if video is not None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + image_latents, video_latents = self.prepare_video_latents( + video, + latents_mean=latents_mean, + latents_std=latents_std, + num_latent_frames_per_chunk=num_latent_frames_per_chunk, + dtype=torch.float32, + device=device, + generator=generator, + latents=video_latents, + ) + + if video_latents is not None and add_noise_to_video_latents: + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + + noisy_latents_chunks = [] + num_latent_chunks = video_latents.shape[2] // num_latent_frames_per_chunk + for i in 
range(num_latent_chunks): + chunk_start = i * num_latent_frames_per_chunk + chunk_end = chunk_start + num_latent_frames_per_chunk + latent_chunk = video_latents[:, :, chunk_start:chunk_end, :, :] + + chunk_frames = latent_chunk.shape[2] + frame_sigmas = ( + torch.rand(chunk_frames, device=device, generator=generator) + * (video_noise_sigma_max - video_noise_sigma_min) + + video_noise_sigma_min + ) + frame_sigmas = frame_sigmas.view(1, 1, chunk_frames, 1, 1) + + noisy_chunk = ( + frame_sigmas * randn_tensor(latent_chunk.shape, generator=generator, device=device) + + (1 - frame_sigmas) * latent_chunk + ) + noisy_latents_chunks.append(noisy_chunk) + video_latents = torch.cat(noisy_latents_chunks, dim=2) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + num_latent_chunk = max(1, (num_frames + window_num_frames - 1) // window_num_frames) + history_video = None + total_generated_latent_frames = 0 + + if not keep_first_frame: + history_sizes[-1] = history_sizes[-1] + 1 + history_latents = torch.zeros( + batch_size, + num_channels_latents, + sum(history_sizes), + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + if fake_image_latents is not None: + history_latents = torch.cat([history_latents, fake_image_latents], dim=2) + total_generated_latent_frames += 1 + if video_latents is not None: + history_frames = history_latents.shape[2] + video_frames = video_latents.shape[2] + if video_frames < history_frames: + keep_frames = history_frames - video_frames + history_latents = torch.cat([history_latents[:, :, :keep_frames, :, :], video_latents], dim=2) + else: + history_latents = video_latents + total_generated_latent_frames += video_latents.shape[2] + + # 6. 
Denoising loop + if use_interpolate_prompt: + if num_latent_chunk < max(interpolate_cumulative_list): + num_latent_chunk = sum(interpolate_cumulative_list) + print(f"Update num_latent_chunk to: {num_latent_chunk}") + + for k in range(num_latent_chunk): + if use_interpolate_prompt: + assert num_latent_chunk >= max(interpolate_cumulative_list) + + current_interval_idx = 0 + for idx, cumulative_val in enumerate(interpolate_cumulative_list): + if k < cumulative_val: + current_interval_idx = idx + break + + if current_interval_idx == 0: + prompt_embeds = all_prompt_embeds[0].unsqueeze(0) + else: + interval_start = interpolate_cumulative_list[current_interval_idx - 1] + position_in_interval = k - interval_start + + if position_in_interval < interpolation_steps: + if interpolate_embeds is None or interpolate_interval_idx != current_interval_idx: + interpolate_embeds = self.interpolate_prompt_embeds( + prompt_embeds_1=all_prompt_embeds[current_interval_idx - 1].unsqueeze(0), + prompt_embeds_2=all_prompt_embeds[current_interval_idx].unsqueeze(0), + interpolation_steps=interpolation_steps, + ) + interpolate_interval_idx = current_interval_idx + + prompt_embeds = interpolate_embeds[position_in_interval] + else: + prompt_embeds = all_prompt_embeds[current_interval_idx].unsqueeze(0) + else: + prompt_embeds = all_prompt_embeds + + is_first_chunk = k == 0 + is_second_chunk = k == 1 + if keep_first_frame: + if is_first_chunk: + history_sizes_first_chunk = [1] + history_sizes.copy() + history_latents_first_chunk = torch.zeros( + batch_size, + num_channels_latents, + sum(history_sizes_first_chunk), + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + if fake_image_latents is not None: + history_latents_first_chunk = torch.cat( + [history_latents_first_chunk, fake_image_latents], dim=2 + ) + if video_latents is not None: + history_frames = history_latents_first_chunk.shape[2] + video_frames = 
video_latents.shape[2] + if video_frames < history_frames: + keep_frames = history_frames - video_frames + history_latents_first_chunk = torch.cat( + [history_latents_first_chunk[:, :, :keep_frames, :, :], video_latents], dim=2 + ) + else: + history_latents_first_chunk = video_latents + + indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) + ( + indices_prefix, + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_1x, + indices_hidden_states, + ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) + indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + + latents_prefix, latents_history_long, latents_history_mid, latents_history_1x = ( + history_latents_first_chunk[:, :, -sum(history_sizes_first_chunk) :].split( + history_sizes_first_chunk, dim=2 + ) + ) + if image_latents is not None: + latents_prefix = image_latents + latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + else: + indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) + ( + indices_prefix, + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_1x, + indices_hidden_states, + ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) + indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + + latents_prefix = image_latents + latents_history_long, latents_history_mid, latents_history_1x = history_latents[ + :, :, -sum(history_sizes) : + ].split(history_sizes, dim=2) + latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + else: + indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk])) + ( + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_short, + indices_hidden_states, + ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0) + 
latents_history_long, latents_history_mid, latents_history_short = history_latents[ + :, :, -sum(history_sizes) : + ].split(history_sizes, dim=2) + + indices_hidden_states = indices_hidden_states.unsqueeze(0) + indices_latents_history_short = indices_latents_history_short.unsqueeze(0) + indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) + indices_latents_history_long = indices_latents_history_long.unsqueeze(0) + + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height, + width, + window_num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=None, + ) + + num_inference_steps = ( + sum(stage2_num_inference_steps_list) * 2 + if is_amplify_first_chunk and use_dmd and is_first_chunk + else sum(stage2_num_inference_steps_list) + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + batch_size, num_channel, num_frmaes, pyramid_height, pyramid_width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frmaes, num_channel, pyramid_height, pyramid_width + ) + for _ in range(stage2_num_stages - 1): + pyramid_height //= 2 + pyramid_width //= 2 + latents = ( + F.interpolate( + latents, + size=(pyramid_height, pyramid_width), + mode="bilinear", + ) + * 2 + ) + latents = latents.reshape(batch_size, num_frmaes, num_channel, pyramid_height, pyramid_width).permute( + 0, 2, 1, 3, 4 + ) + + start_point_list = None + if use_dmd: + start_point_list = [latents] + + for i_s in range(stage2_num_stages): + patch_size = self.transformer.config.patch_size + image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( + patch_size[0] * patch_size[1] * patch_size[2] + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + 
self.scheduler.set_timesteps( + stage2_num_inference_steps_list[i_s], + i_s, + device=device, + mu=mu, + is_amplify_first_chunk=is_amplify_first_chunk and is_first_chunk, + ) + timesteps = self.scheduler.timesteps + num_warmup_steps = 0 + self._num_timesteps = len(timesteps) + + if i_s > 0: + pyramid_height *= 2 + pyramid_width *= 2 + num_frames = latents.shape[2] + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frmaes, num_channel, pyramid_height // 2, pyramid_width // 2 + ) + latents = F.interpolate(latents, size=(pyramid_height, pyramid_width), mode="nearest") + latents = latents.reshape( + batch_size, num_frmaes, num_channel, pyramid_height, pyramid_width + ).permute(0, 2, 1, 3, 4) + # Fix the stage + ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s] # the original coeff of signal + gamma = self.scheduler.config.gamma + alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma) + beta = alpha * (1 - ori_sigma) / math.sqrt(gamma) + + batch_size, channel, num_frames, pyramid_height, pyramid_width = latents.shape + noise = self.sample_block_noise(batch_size, channel, num_frames, pyramid_height, pyramid_width) + noise = noise.to(device=device, dtype=transformer_dtype) + latents = alpha * latents + beta * noise # To fix the block artifact + + if use_dmd: + start_point_list.append(latents) + + for i, t in enumerate(timesteps): + timestep = t.expand(latents.shape[0]).to(torch.int64) + + latent_model_input = latents.to(transformer_dtype) + latents_history_short = latents_history_short.to(transformer_dtype) + latents_history_mid = latents_history_mid.to(transformer_dtype) + latents_history_long = latents_history_long.to(transformer_dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + 
indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if use_cfg_zero_star: + noise_pred_text = noise_pred + positive_flat = noise_pred_text.view(batch_size, -1) + negative_flat = noise_uncond.view(batch_size, -1) + + alpha = optimized_scale(positive_flat, negative_flat) + alpha = alpha.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1))) + alpha = alpha.to(noise_pred_text.dtype) + + if (i_s == 0 and i <= zero_steps) and use_zero_init: + noise_pred = noise_pred_text * 0.0 + else: + noise_pred = noise_uncond * alpha + guidance_scale * ( + noise_pred_text - noise_uncond * alpha + ) + else: + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + return_dict=False, + cur_sampling_step=i, + dmd_noisy_tensor=start_point_list[i_s] if start_point_list is not None else None, + dmd_sigmas=self.scheduler.sigmas, + dmd_timesteps=self.scheduler.timesteps, + all_timesteps=timesteps, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + 
for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if keep_first_frame and ( + (is_first_chunk and image_latents is None) or (is_skip_first_chunk and is_second_chunk) + ): + image_latents = latents[:, :, 0:1, :, :] + + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([history_latents, latents], dim=2) + real_history_latents = history_latents[:, :, -total_generated_latent_frames:] + index_slice = ( + slice(None), + slice(None), + slice(-num_latent_frames_per_chunk, None), + ) + + current_latents = real_history_latents[index_slice].to(vae_dtype) / latents_std + latents_mean + current_video = self.vae.decode(current_latents, return_dict=False)[0] + + if history_video is None: + history_video = current_video + else: + history_video = torch.cat([history_video, current_video], dim=2) + + self._current_timestep = None + + if output_type != "latent": + generated_frames = history_video.size(2) + generated_frames = ( + generated_frames - 1 + ) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + history_video = history_video[:, :, :generated_frames] + video = self.video_processor.postprocess_video(history_video, output_type=output_type) + else: + video = real_history_latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HeliosPipelineOutput(frames=video) diff --git a/src/diffusers/schedulers/scheduling_helios.py b/src/diffusers/schedulers/scheduling_helios.py index 
691d13943117..1d3a16ed12f2 100644 --- a/src/diffusers/schedulers/scheduling_helios.py +++ b/src/diffusers/schedulers/scheduling_helios.py @@ -195,7 +195,7 @@ def _sigma_to_t(self, sigma): def set_timesteps( self, num_inference_steps: int, - stage_index: int| None = None, + stage_index: int | None = None, device: str | torch.device = None, sigmas: bool | None = None, mu: bool | None = None, @@ -215,10 +215,12 @@ def set_timesteps( if self.config.stages == 1: if sigmas is None: - sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype(np.float32) + sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype( + np.float32 + ) if self.config.shift != 1.0: assert not self.config.use_dynamic_shifting - sigmas = self.time_shift(shift, 1.0, sigmas) + sigmas = self.time_shift(self.config.shift, 1.0, sigmas) timesteps = (sigmas * self.config.num_train_timesteps).copy() sigmas = torch.from_numpy(sigmas) else: diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index a5830c9f0fd7..b86b5d2c6f4d 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1367,6 +1367,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HeliosPyramidPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HiDreamImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 5a87c0c3958e7a2040e2d14c842d50da2b034559 Mon Sep 17 00:00:00 2001 From: 
Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Sat, 28 Feb 2026 10:52:53 +0800 Subject: [PATCH 045/107] Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_helios.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index a928b5eabf9e..755cb6169b56 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -657,7 +657,7 @@ def forward( # 2. Process noisy latents hidden_states = self.patch_embedding(hidden_states) - B, C, T, H, W = hidden_states.shape + _, original_context_length, post_patch_num_frames, post_patch_height, post_patch_width = hidden_states.shape if indices_hidden_states is None: indices_hidden_states = torch.arange(0, T).unsqueeze(0).expand(B, -1) From a573460add3d8e802d9cacd6cd656019f98441c9 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Sat, 28 Feb 2026 10:53:25 +0800 Subject: [PATCH 046/107] Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_helios.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 755cb6169b56..9c8afa3e67c4 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -671,10 +671,6 @@ def forward( ) rotary_emb = rotary_emb.flatten(2).transpose(1, 2) - post_patch_height = H - post_patch_width = W - post_patch_num_frames = T - original_context_length = hidden_states.shape[1] # 3. 
Process short history latents if latents_history_short is not None and indices_latents_history_short is not None: From 94988e8bbef09b892da18055e6e0a78bdb856d17 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Sat, 28 Feb 2026 10:53:35 +0800 Subject: [PATCH 047/107] Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_helios.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 9c8afa3e67c4..286ae2b38a75 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -665,7 +665,7 @@ def forward( hidden_states = hidden_states.flatten(2).transpose(1, 2) rotary_emb = self.rope( frame_indices=indices_hidden_states, - height=H, + height=post_patch_height, width=W, device=hidden_states.device, ) From 0171b3f003ac0da15efeeef907ae665f8c52c48e Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Sat, 28 Feb 2026 10:53:44 +0800 Subject: [PATCH 048/107] Update src/diffusers/models/transformers/transformer_helios.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_helios.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 286ae2b38a75..bdbee2f2c89e 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -666,7 +666,7 @@ def forward( rotary_emb = self.rope( frame_indices=indices_hidden_states, height=post_patch_height, - width=W, + width=post_patch_width device=hidden_states.device, ) rotary_emb = rotary_emb.flatten(2).transpose(1, 2) From 
e4efddf80e639bca48e9a46f8a1ee2511f56a128 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sat, 28 Feb 2026 03:09:49 +0000 Subject: [PATCH 049/107] fix transformer --- src/diffusers/models/transformers/transformer_helios.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index bdbee2f2c89e..3c9eb1fbb5f7 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -657,20 +657,20 @@ def forward( # 2. Process noisy latents hidden_states = self.patch_embedding(hidden_states) - _, original_context_length, post_patch_num_frames, post_patch_height, post_patch_width = hidden_states.shape + _, _, post_patch_num_frames, post_patch_height, post_patch_width = hidden_states.shape if indices_hidden_states is None: - indices_hidden_states = torch.arange(0, T).unsqueeze(0).expand(B, -1) + indices_hidden_states = torch.arange(0, post_patch_num_frames).unsqueeze(0).expand(batch_size, -1) hidden_states = hidden_states.flatten(2).transpose(1, 2) rotary_emb = self.rope( frame_indices=indices_hidden_states, height=post_patch_height, - width=post_patch_width + width=post_patch_width, device=hidden_states.device, ) rotary_emb = rotary_emb.flatten(2).transpose(1, 2) - + original_context_length = hidden_states.shape[1] # 3. 
Process short history latents if latents_history_short is not None and indices_latents_history_short is not None: From e1bb744251f9a1b5c1ce2c396da078f79347d31d Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sat, 28 Feb 2026 03:15:44 +0000 Subject: [PATCH 050/107] use a more intuitive name --- 0_temp_helios_test/infer_helios.py | 9 ++++--- 0_temp_helios_test/stage-2_i2v.sh | 4 ++-- 0_temp_helios_test/stage-2_t2v.sh | 4 ++-- 0_temp_helios_test/stage-2_v2v.sh | 4 ++-- 0_temp_helios_test/stage-3_i2v.sh | 4 ++-- 0_temp_helios_test/stage-3_t2v.sh | 4 ++-- 0_temp_helios_test/stage-3_v2v.sh | 4 ++-- docs/source/en/api/pipelines/helios.md | 24 +++++++++---------- .../helios/pipeline_helios_pyramid.py | 16 ++++++------- 9 files changed, 36 insertions(+), 37 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index 37383f6c55ec..f30eda5e928a 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -72,11 +72,11 @@ def parse_args(): parser.add_argument("--num_latent_frames_per_chunk", type=int, default=9) # stage 2 parser.add_argument("--is_enable_stage2", action="store_true") - parser.add_argument("--stage2_num_stages", type=int, default=3) + parser.add_argument("--pyramid_num_stages", type=int, default=3) parser.add_argument("--stage2_timestep_shift", type=float, default=1.0) parser.add_argument("--stage2_scheduler_gamma", type=float, default=1 / 3) parser.add_argument("--stage2_stage_range", type=int, nargs="+", default=[0, 1 / 3, 2 / 3, 1]) - parser.add_argument("--stage2_num_inference_steps_list", type=int, nargs="+", default=[20, 20, 20]) + parser.add_argument("--pyramid_num_inference_steps_list", type=int, nargs="+", default=[20, 20, 20]) # stage 3 parser.add_argument("--is_enable_stage3", action="store_true") parser.add_argument("--is_skip_first_chunk", action="store_true") @@ -288,10 +288,9 @@ def main(): keep_first_frame=True, is_skip_first_chunk=args.is_skip_first_chunk, # stage 2 - 
stage2_num_stages=args.stage2_num_stages, - stage2_num_inference_steps_list=args.stage2_num_inference_steps_list, + pyramid_num_stages=args.pyramid_num_stages, + pyramid_num_inference_steps_list=args.pyramid_num_inference_steps_list, # stage 3 - use_dmd=args.is_enable_stage3, is_amplify_first_chunk=args.is_amplify_first_chunk, # cfg zero use_cfg_zero_star=args.use_cfg_zero_star, diff --git a/0_temp_helios_test/stage-2_i2v.sh b/0_temp_helios_test/stage-2_i2v.sh index 9e235e7a5dbd..73e8f61e2271 100644 --- a/0_temp_helios_test/stage-2_i2v.sh +++ b/0_temp_helios_test/stage-2_i2v.sh @@ -5,13 +5,13 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" \ --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." 
\ --is_enable_stage2 \ - --stage2_num_inference_steps_list 20 20 20 \ + --pyramid_num_inference_steps_list 20 20 20 \ --use_cfg_zero_star \ --use_zero_init \ --zero_steps 1 \ --output_folder "./output_helios/stage-2" - # --stage2_num_inference_steps_list 17 17 17 \ + # --pyramid_num_inference_steps_list 17 17 17 \ # --use_default_loader \ # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-2_t2v.sh b/0_temp_helios_test/stage-2_t2v.sh index 36590dc1fce1..706113311d08 100644 --- a/0_temp_helios_test/stage-2_t2v.sh +++ b/0_temp_helios_test/stage-2_t2v.sh @@ -4,13 +4,13 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --sample_type "t2v" \ --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." 
\ --is_enable_stage2 \ - --stage2_num_inference_steps_list 20 20 20 \ + --pyramid_num_inference_steps_list 20 20 20 \ --use_cfg_zero_star \ --use_zero_init \ --zero_steps 1 \ --output_folder "./output_helios/stage-2" - # --stage2_num_inference_steps_list 17 17 17 \ + # --pyramid_num_inference_steps_list 17 17 17 \ # --use_default_loader \ # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-2_v2v.sh b/0_temp_helios_test/stage-2_v2v.sh index 8f4181af607c..1fdb5d25daee 100644 --- a/0_temp_helios_test/stage-2_v2v.sh +++ b/0_temp_helios_test/stage-2_v2v.sh @@ -5,13 +5,13 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ --is_enable_stage2 \ - --stage2_num_inference_steps_list 20 20 20 \ + --pyramid_num_inference_steps_list 20 20 20 \ --use_cfg_zero_star \ --use_zero_init \ --zero_steps 1 \ --output_folder "./output_helios/stage-2" - # --stage2_num_inference_steps_list 17 17 17 \ + # --pyramid_num_inference_steps_list 17 17 17 \ # --use_default_loader \ # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-3_i2v.sh b/0_temp_helios_test/stage-3_i2v.sh index f1adf67d337c..9a05df9437c7 100644 --- a/0_temp_helios_test/stage-3_i2v.sh +++ b/0_temp_helios_test/stage-3_i2v.sh @@ -7,11 +7,11 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --num_frames 240 \ --guidance_scale 1.0 \ --is_enable_stage2 \ - --stage2_num_inference_steps_list 2 2 2 \ + --pyramid_num_inference_steps_list 2 2 2 \ --is_enable_stage3 \ --is_amplify_first_chunk \ --output_folder "./output_helios/stage-3" - # --stage2_num_inference_steps_list 1 1 1 \ + # --pyramid_num_inference_steps_list 1 1 1 \ # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-3_t2v.sh b/0_temp_helios_test/stage-3_t2v.sh index 1f1efc04df41..02d8ca33dbf8 100644 --- a/0_temp_helios_test/stage-3_t2v.sh +++ b/0_temp_helios_test/stage-3_t2v.sh @@ -6,11 +6,11 @@ CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ --num_frames 240 \ --guidance_scale 1.0 \ --is_enable_stage2 \ - --stage2_num_inference_steps_list 2 2 2 \ + --pyramid_num_inference_steps_list 2 2 2 \ --is_enable_stage3 \ --is_amplify_first_chunk \ --output_folder "./output_helios/stage-3" - # --stage2_num_inference_steps_list 1 1 1 \ + # --pyramid_num_inference_steps_list 1 1 1 \ # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-3_v2v.sh b/0_temp_helios_test/stage-3_v2v.sh index 11a1a5fb6fcf..9193a4ccba79 100644 --- a/0_temp_helios_test/stage-3_v2v.sh +++ b/0_temp_helios_test/stage-3_v2v.sh @@ -7,11 +7,11 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --num_frames 240 \ 
--guidance_scale 1.0 \ --is_enable_stage2 \ - --stage2_num_inference_steps_list 2 2 2 \ + --pyramid_num_inference_steps_list 2 2 2 \ --is_enable_stage3 \ --is_amplify_first_chunk \ --output_folder "./output_helios/stage-3" - # --stage2_num_inference_steps_list 1 1 1 \ + # --pyramid_num_inference_steps_list 1 1 1 \ # --enable_compile \ \ No newline at end of file diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index 50a65af08b72..952dd77a6cf5 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -281,8 +281,8 @@ output = pipeline( num_frames=99, use_dynamic_shifting=True, is_enable_stage2=True, - stage2_num_stages=3, - stage2_num_inference_steps_list=[20, 20, 20], + pyramid_num_stages=3, + pyramid_num_inference_steps_list=[20, 20, 20], use_cfg_zero_star=True, use_zero_init=True, zero_steps=1, @@ -307,8 +307,8 @@ output = pipeline( num_frames=99, use_dynamic_shifting=True, is_enable_stage2=True, - stage2_num_stages=3, - stage2_num_inference_steps_list=[20, 20, 20], + pyramid_num_stages=3, + pyramid_num_inference_steps_list=[20, 20, 20], use_cfg_zero_star=True, use_zero_init=True, zero_steps=1, @@ -332,8 +332,8 @@ output = pipeline( num_frames=99, use_dynamic_shifting=True, is_enable_stage2=True, - stage2_num_stages=3, - stage2_num_inference_steps_list=[20, 20, 20], + pyramid_num_stages=3, + pyramid_num_inference_steps_list=[20, 20, 20], use_cfg_zero_star=True, use_zero_init=True, zero_steps=1, @@ -387,8 +387,8 @@ output = pipeline( num_frames=99, use_dynamic_shifting=True, is_enable_stage2=True, - stage2_num_stages=3, - stage2_num_inference_steps_list=[2, 2, 2], + pyramid_num_stages=3, + pyramid_num_inference_steps_list=[2, 2, 2], use_dmd=True, guidance_scale=1.0, is_amplify_first_chunk=True, @@ -413,8 +413,8 @@ output = pipeline( num_frames=99, use_dynamic_shifting=True, is_enable_stage2=True, - stage2_num_stages=3, - stage2_num_inference_steps_list=[2, 2, 2], + 
pyramid_num_stages=3, + pyramid_num_inference_steps_list=[2, 2, 2], use_dmd=True, guidance_scale=1.0, is_amplify_first_chunk=True, @@ -438,8 +438,8 @@ output = pipeline( num_frames=99, use_dynamic_shifting=True, is_enable_stage2=True, - stage2_num_stages=3, - stage2_num_inference_steps_list=[2, 2, 2], + pyramid_num_stages=3, + pyramid_num_inference_steps_list=[2, 2, 2], use_dmd=True, guidance_scale=1.0, is_amplify_first_chunk=True, diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 2d69303fec91..a793c7628218 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -533,14 +533,13 @@ def __call__( keep_first_frame: bool = True, is_skip_first_chunk: bool = False, # ------------ Stage 2 ------------ - stage2_num_stages: int = 3, - stage2_num_inference_steps_list: list = [10, 10, 10], + pyramid_num_stages: int = 3, + pyramid_num_inference_steps_list: list = [10, 10, 10], # ------------ CFG Zero ------------ use_cfg_zero_star: bool | None = False, use_zero_init: bool | None = True, zero_steps: int | None = 1, # ------------ DMD ------------ - use_dmd: bool = False, is_amplify_first_chunk: bool = False, ): r""" @@ -613,6 +612,7 @@ def __call__( raise ValueError("image and video cannot be provided simultaneously") history_sizes = sorted(history_sizes, reverse=True) # From big to small + use_dmd = True if self.scheduler.scheduler_type == "dmd" else False latents_mean = ( torch.tensor(self.vae.config.latents_mean) @@ -923,9 +923,9 @@ def __call__( ) num_inference_steps = ( - sum(stage2_num_inference_steps_list) * 2 + sum(pyramid_num_inference_steps_list) * 2 if is_amplify_first_chunk and use_dmd and is_first_chunk - else sum(stage2_num_inference_steps_list) + else sum(pyramid_num_inference_steps_list) ) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -933,7 +933,7 @@ def __call__( latents 
= latents.permute(0, 2, 1, 3, 4).reshape( batch_size * num_frmaes, num_channel, pyramid_height, pyramid_width ) - for _ in range(stage2_num_stages - 1): + for _ in range(pyramid_num_stages - 1): pyramid_height //= 2 pyramid_width //= 2 latents = ( @@ -952,7 +952,7 @@ def __call__( if use_dmd: start_point_list = [latents] - for i_s in range(stage2_num_stages): + for i_s in range(pyramid_num_stages): patch_size = self.transformer.config.patch_size image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( patch_size[0] * patch_size[1] * patch_size[2] @@ -965,7 +965,7 @@ def __call__( self.scheduler.config.get("max_shift", 1.15), ) self.scheduler.set_timesteps( - stage2_num_inference_steps_list[i_s], + pyramid_num_inference_steps_list[i_s], i_s, device=device, mu=mu, From 7a1785ecf4a0fd5e646c403151ee4c29b9932a62 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sat, 28 Feb 2026 04:34:39 +0000 Subject: [PATCH 051/107] update example script --- 0_temp_helios_test/stage-1_i2v.sh | 1 + 0_temp_helios_test/stage-1_t2v.sh | 1 + 0_temp_helios_test/stage-1_v2v.sh | 1 + 0_temp_helios_test/stage-2_i2v.sh | 1 + 0_temp_helios_test/stage-2_t2v.sh | 1 + 0_temp_helios_test/stage-2_v2v.sh | 1 + 6 files changed, 6 insertions(+) diff --git a/0_temp_helios_test/stage-1_i2v.sh b/0_temp_helios_test/stage-1_i2v.sh index 968fb82a41b0..664892db3c14 100644 --- a/0_temp_helios_test/stage-1_i2v.sh +++ b/0_temp_helios_test/stage-1_i2v.sh @@ -3,6 +3,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --transformer_path "BestWishYsh/Helios-Base" \ --sample_type "i2v" \ --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" \ + --guidance_scale 5.0 \ --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. 
A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." \ --output_folder "./output_helios/stage-1" diff --git a/0_temp_helios_test/stage-1_t2v.sh b/0_temp_helios_test/stage-1_t2v.sh index 43756af73580..9932ccdbbee1 100644 --- a/0_temp_helios_test/stage-1_t2v.sh +++ b/0_temp_helios_test/stage-1_t2v.sh @@ -3,6 +3,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --transformer_path "BestWishYsh/Helios-Base" \ --sample_type "t2v" \ --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." 
\ + --guidance_scale 5.0 \ --output_folder "./output_helios/stage-1" diff --git a/0_temp_helios_test/stage-1_v2v.sh b/0_temp_helios_test/stage-1_v2v.sh index 25f882853f37..626403936e69 100644 --- a/0_temp_helios_test/stage-1_v2v.sh +++ b/0_temp_helios_test/stage-1_v2v.sh @@ -4,6 +4,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --sample_type "v2v" \ --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." \ + --guidance_scale 5.0 \ --output_folder "./output_helios/stage-1" diff --git a/0_temp_helios_test/stage-2_i2v.sh b/0_temp_helios_test/stage-2_i2v.sh index 73e8f61e2271..911d92d747f0 100644 --- a/0_temp_helios_test/stage-2_i2v.sh +++ b/0_temp_helios_test/stage-2_i2v.sh @@ -4,6 +4,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --sample_type "i2v" \ --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" \ --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. 
As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." \ + --guidance_scale 5.0 \ --is_enable_stage2 \ --pyramid_num_inference_steps_list 20 20 20 \ --use_cfg_zero_star \ diff --git a/0_temp_helios_test/stage-2_t2v.sh b/0_temp_helios_test/stage-2_t2v.sh index 706113311d08..eacf82d69cf1 100644 --- a/0_temp_helios_test/stage-2_t2v.sh +++ b/0_temp_helios_test/stage-2_t2v.sh @@ -3,6 +3,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --transformer_path "BestWishYsh/Helios-Mid" \ --sample_type "t2v" \ --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." 
\ + --guidance_scale 5.0 \ --is_enable_stage2 \ --pyramid_num_inference_steps_list 20 20 20 \ --use_cfg_zero_star \ diff --git a/0_temp_helios_test/stage-2_v2v.sh b/0_temp_helios_test/stage-2_v2v.sh index 1fdb5d25daee..b0b41ddb89ab 100644 --- a/0_temp_helios_test/stage-2_v2v.sh +++ b/0_temp_helios_test/stage-2_v2v.sh @@ -4,6 +4,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --sample_type "v2v" \ --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ + --guidance_scale 5.0 \ --is_enable_stage2 \ --pyramid_num_inference_steps_list 20 20 20 \ --use_cfg_zero_star \ From f811ce268132541b749ca51295ae43f7a7d2156f Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Sat, 28 Feb 2026 09:04:45 +0000 Subject: [PATCH 052/107] fix requirements --- 0_temp_helios_test/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/0_temp_helios_test/requirements.txt b/0_temp_helios_test/requirements.txt index d3c12e60923a..b381119d3944 100644 --- a/0_temp_helios_test/requirements.txt +++ b/0_temp_helios_test/requirements.txt @@ -15,6 +15,7 @@ huggingface-hub==1.4.1 zstandard==0.25.0 wandb==0.23.0 video-reader-rs==0.4.1 +numpy<2.0.0 opencv-python gradio spaces From 926b0d6ae01da9b89fe8555c57a4a57f3f63a2d1 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Sat, 28 Feb 2026 12:20:05 +0000 Subject: [PATCH 053/107] remove redudant attention mask --- src/diffusers/pipelines/helios/pipeline_helios.py | 8 ++++---- src/diffusers/pipelines/helios/pipeline_helios_pyramid.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 43d7732cf10f..45bf8997daf2 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -251,7 +251,7 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt_embeds, _ = self._get_t5_prompt_embeds( prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, @@ -276,7 +276,7 @@ def encode_prompt( " the batch size of `prompt`." 
) - negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + negative_prompt_embeds, _ = self._get_t5_prompt_embeds( prompt=negative_prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, @@ -284,7 +284,7 @@ def encode_prompt( dtype=dtype, ) - return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + return prompt_embeds, negative_prompt_embeds def check_inputs( self, @@ -654,7 +654,7 @@ def __call__( interpolate_embeds = None interpolate_cumulative_list = list(accumulate(interpolate_time_list)) - all_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( + all_prompt_embeds, negative_prompt_embeds = ( self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index a793c7628218..5a7e2e9eab02 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -252,7 +252,7 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt_embeds, _ = self._get_t5_prompt_embeds( prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, @@ -277,7 +277,7 @@ def encode_prompt( " the batch size of `prompt`." 
) - negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + negative_prompt_embeds, _ = self._get_t5_prompt_embeds( prompt=negative_prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, @@ -285,7 +285,7 @@ def encode_prompt( dtype=dtype, ) - return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + return prompt_embeds, negative_prompt_embeds def check_inputs( self, @@ -665,7 +665,7 @@ def __call__( interpolate_embeds = None interpolate_cumulative_list = list(accumulate(interpolate_time_list)) - all_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( + all_prompt_embeds, negative_prompt_embeds = ( self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, From 244d994dbbb8fe24b36d58c1c40fdf6a74dcf2f5 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Sat, 28 Feb 2026 12:21:51 +0000 Subject: [PATCH 054/107] fix --- src/diffusers/pipelines/helios/pipeline_helios_pyramid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 5a7e2e9eab02..13677e251d75 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -612,7 +612,7 @@ def __call__( raise ValueError("image and video cannot be provided simultaneously") history_sizes = sorted(history_sizes, reverse=True) # From big to small - use_dmd = True if self.scheduler.scheduler_type == "dmd" else False + use_dmd = True if self.scheduler.config.scheduler_type == "dmd" else False latents_mean = ( torch.tensor(self.vae.config.latents_mean) From c8f571dac8adfff526c84c2f7cb54e1905d8739d Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sat, 28 Feb 2026 13:17:39 +0000 Subject: [PATCH 055/107] optimize pipelines --- 0_temp_helios_test/stage-3_i2v.sh | 2 
+- 0_temp_helios_test/stage-3_v2v.sh | 2 +- .../pipelines/helios/pipeline_helios.py | 125 ++++++------------ .../helios/pipeline_helios_pyramid.py | 110 +++++---------- 4 files changed, 80 insertions(+), 159 deletions(-) diff --git a/0_temp_helios_test/stage-3_i2v.sh b/0_temp_helios_test/stage-3_i2v.sh index 9a05df9437c7..e903a1475e2f 100644 --- a/0_temp_helios_test/stage-3_i2v.sh +++ b/0_temp_helios_test/stage-3_i2v.sh @@ -1,4 +1,4 @@ -CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ +CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ --base_model_path "BestWishYsh/Helios-Distilled" \ --transformer_path "BestWishYsh/Helios-Distilled" \ --sample_type "i2v" \ diff --git a/0_temp_helios_test/stage-3_v2v.sh b/0_temp_helios_test/stage-3_v2v.sh index 9193a4ccba79..24fbacf64312 100644 --- a/0_temp_helios_test/stage-3_v2v.sh +++ b/0_temp_helios_test/stage-3_v2v.sh @@ -1,4 +1,4 @@ -CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ +CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ --base_model_path "BestWishYsh/Helios-Distilled" \ --transformer_path "BestWishYsh/Helios-Distilled" \ --sample_type "v2v" \ diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 45bf8997daf2..7fd4d4a72f1a 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -782,11 +782,47 @@ def __call__( history_latents = video_latents total_generated_latent_frames += video_latents.shape[2] + if keep_first_frame: + indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) + ( + indices_prefix, + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_1x, + indices_hidden_states, + ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) + indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + else: + indices = torch.arange(0, sum([*history_sizes, 
num_latent_frames_per_chunk])) + ( + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_short, + indices_hidden_states, + ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0) + indices_hidden_states = indices_hidden_states.unsqueeze(0) + indices_latents_history_short = indices_latents_history_short.unsqueeze(0) + indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) + indices_latents_history_long = indices_latents_history_long.unsqueeze(0) + # 6. Denoising loop if use_interpolate_prompt: if num_latent_chunk < max(interpolate_cumulative_list): num_latent_chunk = sum(interpolate_cumulative_list) print(f"Update num_latent_chunk to: {num_latent_chunk}") + + patch_size = self.transformer.config.patch_size + image_seq_len = num_latent_frames_per_chunk * (height // self.vae_scale_factor_spatial) * (width // self.vae_scale_factor_spatial) // ( + patch_size[0] * patch_size[1] * patch_size[2] + ) + sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) for k in range(num_latent_chunk): if use_interpolate_prompt: @@ -822,83 +858,20 @@ def __call__( is_first_chunk = k == 0 is_second_chunk = k == 1 if keep_first_frame: - if is_first_chunk: - history_sizes_first_chunk = [1] + history_sizes.copy() - history_latents_first_chunk = torch.zeros( - batch_size, - num_channels_latents, - sum(history_sizes_first_chunk), - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial, - device=device, - dtype=torch.float32, - ) - if fake_image_latents is not None: - history_latents_first_chunk = torch.cat( - [history_latents_first_chunk, fake_image_latents], dim=2 - ) - if video_latents is not None: - 
history_frames = history_latents_first_chunk.shape[2] - video_frames = video_latents.shape[2] - if video_frames < history_frames: - keep_frames = history_frames - video_frames - history_latents_first_chunk = torch.cat( - [history_latents_first_chunk[:, :, :keep_frames, :, :], video_latents], dim=2 - ) - else: - history_latents_first_chunk = video_latents - - indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) - ( - indices_prefix, - indices_latents_history_long, - indices_latents_history_mid, - indices_latents_history_1x, - indices_hidden_states, - ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) - indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) - - latents_prefix, latents_history_long, latents_history_mid, latents_history_1x = ( - history_latents_first_chunk[:, :, -sum(history_sizes_first_chunk) :].split( - history_sizes_first_chunk, dim=2 - ) - ) - if image_latents is not None: - latents_prefix = image_latents - latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + latents_history_long, latents_history_mid, latents_history_1x = history_latents[ + :, :, -sum(history_sizes) : + ].split(history_sizes, dim=2) + if image_latents is None and is_first_chunk: + latents_prefix = torch.zeros((batch_size, num_channels_latents, 1, latents_history_1x.shape[-2], latents_history_1x.shape[-1]), + device=latents_history_1x.device, dtype=latents_history_1x.dtype) else: - indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) - ( - indices_prefix, - indices_latents_history_long, - indices_latents_history_mid, - indices_latents_history_1x, - indices_hidden_states, - ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) - indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) - latents_prefix = image_latents - latents_history_long, latents_history_mid, latents_history_1x 
= history_latents[ - :, :, -sum(history_sizes) : - ].split(history_sizes, dim=2) - latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) else: - indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk])) - ( - indices_latents_history_long, - indices_latents_history_mid, - indices_latents_history_short, - indices_hidden_states, - ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0) latents_history_long, latents_history_mid, latents_history_short = history_latents[ :, :, -sum(history_sizes) : ].split(history_sizes, dim=2) - indices_hidden_states = indices_hidden_states.unsqueeze(0) - indices_latents_history_short = indices_latents_history_short.unsqueeze(0) - indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) - indices_latents_history_long = indices_latents_history_long.unsqueeze(0) - latents = self.prepare_latents( batch_size, num_channels_latents, @@ -911,18 +884,6 @@ def __call__( latents=None, ) - patch_size = self.transformer.config.patch_size - image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( - patch_size[0] * patch_size[1] * patch_size[2] - ) - sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas - mu = calculate_shift( - image_seq_len, - self.scheduler.config.get("base_image_seq_len", 256), - self.scheduler.config.get("max_image_seq_len", 4096), - self.scheduler.config.get("base_shift", 0.5), - self.scheduler.config.get("max_shift", 1.15), - ) self.scheduler.set_timesteps(num_inference_steps, device=device, sigmas=sigmas, mu=mu) timesteps = self.scheduler.timesteps num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 13677e251d75..e716bf5569ae 100644 --- 
a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -793,6 +793,29 @@ def __call__( history_latents = video_latents total_generated_latent_frames += video_latents.shape[2] + if keep_first_frame: + indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) + ( + indices_prefix, + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_1x, + indices_hidden_states, + ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) + indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + else: + indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk])) + ( + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_short, + indices_hidden_states, + ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0) + indices_hidden_states = indices_hidden_states.unsqueeze(0) + indices_latents_history_short = indices_latents_history_short.unsqueeze(0) + indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) + indices_latents_history_long = indices_latents_history_long.unsqueeze(0) + # 6. 
Denoising loop if use_interpolate_prompt: if num_latent_chunk < max(interpolate_cumulative_list): @@ -833,83 +856,20 @@ def __call__( is_first_chunk = k == 0 is_second_chunk = k == 1 if keep_first_frame: - if is_first_chunk: - history_sizes_first_chunk = [1] + history_sizes.copy() - history_latents_first_chunk = torch.zeros( - batch_size, - num_channels_latents, - sum(history_sizes_first_chunk), - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial, - device=device, - dtype=torch.float32, - ) - if fake_image_latents is not None: - history_latents_first_chunk = torch.cat( - [history_latents_first_chunk, fake_image_latents], dim=2 - ) - if video_latents is not None: - history_frames = history_latents_first_chunk.shape[2] - video_frames = video_latents.shape[2] - if video_frames < history_frames: - keep_frames = history_frames - video_frames - history_latents_first_chunk = torch.cat( - [history_latents_first_chunk[:, :, :keep_frames, :, :], video_latents], dim=2 - ) - else: - history_latents_first_chunk = video_latents - - indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) - ( - indices_prefix, - indices_latents_history_long, - indices_latents_history_mid, - indices_latents_history_1x, - indices_hidden_states, - ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) - indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) - - latents_prefix, latents_history_long, latents_history_mid, latents_history_1x = ( - history_latents_first_chunk[:, :, -sum(history_sizes_first_chunk) :].split( - history_sizes_first_chunk, dim=2 - ) - ) - if image_latents is not None: - latents_prefix = image_latents - latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + latents_history_long, latents_history_mid, latents_history_1x = history_latents[ + :, :, -sum(history_sizes) : + ].split(history_sizes, dim=2) + if image_latents is None and 
is_first_chunk: + latents_prefix = torch.zeros((batch_size, num_channels_latents, 1, latents_history_1x.shape[-2], latents_history_1x.shape[-1]), + device=latents_history_1x.device, dtype=latents_history_1x.dtype) else: - indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) - ( - indices_prefix, - indices_latents_history_long, - indices_latents_history_mid, - indices_latents_history_1x, - indices_hidden_states, - ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) - indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) - latents_prefix = image_latents - latents_history_long, latents_history_mid, latents_history_1x = history_latents[ - :, :, -sum(history_sizes) : - ].split(history_sizes, dim=2) - latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) else: - indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk])) - ( - indices_latents_history_long, - indices_latents_history_mid, - indices_latents_history_short, - indices_hidden_states, - ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0) latents_history_long, latents_history_mid, latents_history_short = history_latents[ :, :, -sum(history_sizes) : ].split(history_sizes, dim=2) - indices_hidden_states = indices_hidden_states.unsqueeze(0) - indices_latents_history_short = indices_latents_history_short.unsqueeze(0) - indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) - indices_latents_history_long = indices_latents_history_long.unsqueeze(0) - latents = self.prepare_latents( batch_size, num_channels_latents, @@ -929,9 +889,9 @@ def __call__( ) with self.progress_bar(total=num_inference_steps) as progress_bar: - batch_size, num_channel, num_frmaes, pyramid_height, pyramid_width = latents.shape + _, _, _, pyramid_height, pyramid_width = latents.shape latents = 
latents.permute(0, 2, 1, 3, 4).reshape( - batch_size * num_frmaes, num_channel, pyramid_height, pyramid_width + batch_size * num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width ) for _ in range(pyramid_num_stages - 1): pyramid_height //= 2 @@ -944,7 +904,7 @@ def __call__( ) * 2 ) - latents = latents.reshape(batch_size, num_frmaes, num_channel, pyramid_height, pyramid_width).permute( + latents = latents.reshape(batch_size, num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width).permute( 0, 2, 1, 3, 4 ) @@ -980,11 +940,11 @@ def __call__( pyramid_width *= 2 num_frames = latents.shape[2] latents = latents.permute(0, 2, 1, 3, 4).reshape( - batch_size * num_frmaes, num_channel, pyramid_height // 2, pyramid_width // 2 + batch_size * num_latent_frames_per_chunk, num_channels_latents, pyramid_height // 2, pyramid_width // 2 ) latents = F.interpolate(latents, size=(pyramid_height, pyramid_width), mode="nearest") latents = latents.reshape( - batch_size, num_frmaes, num_channel, pyramid_height, pyramid_width + batch_size, num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width ).permute(0, 2, 1, 3, 4) # Fix the stage ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s] # the original coeff of signal From 50b565af2edd0574aa6678aec2d61268dd85b4a2 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sat, 28 Feb 2026 13:18:35 +0000 Subject: [PATCH 056/107] make style . 
--- .../pipelines/helios/pipeline_helios.py | 43 +++++++++------- .../helios/pipeline_helios_pyramid.py | 51 ++++++++++++------- 2 files changed, 58 insertions(+), 36 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 7fd4d4a72f1a..2314d95c899b 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -259,7 +259,6 @@ def encode_prompt( dtype=dtype, ) - negative_prompt_attention_mask = None if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt @@ -654,17 +653,15 @@ def __call__( interpolate_embeds = None interpolate_cumulative_list = list(accumulate(interpolate_time_list)) - all_prompt_embeds, negative_prompt_embeds = ( - self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - max_sequence_length=max_sequence_length, - device=device, - ) + all_prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, ) transformer_dtype = self.transformer.dtype @@ -810,10 +807,13 @@ def __call__( if num_latent_chunk < max(interpolate_cumulative_list): num_latent_chunk = sum(interpolate_cumulative_list) print(f"Update num_latent_chunk to: {num_latent_chunk}") - + patch_size = self.transformer.config.patch_size - image_seq_len = num_latent_frames_per_chunk * (height // 
self.vae_scale_factor_spatial) * (width // self.vae_scale_factor_spatial) // ( - patch_size[0] * patch_size[1] * patch_size[2] + image_seq_len = ( + num_latent_frames_per_chunk + * (height // self.vae_scale_factor_spatial) + * (width // self.vae_scale_factor_spatial) + // (patch_size[0] * patch_size[1] * patch_size[2]) ) sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas mu = calculate_shift( @@ -862,8 +862,17 @@ def __call__( :, :, -sum(history_sizes) : ].split(history_sizes, dim=2) if image_latents is None and is_first_chunk: - latents_prefix = torch.zeros((batch_size, num_channels_latents, 1, latents_history_1x.shape[-2], latents_history_1x.shape[-1]), - device=latents_history_1x.device, dtype=latents_history_1x.dtype) + latents_prefix = torch.zeros( + ( + batch_size, + num_channels_latents, + 1, + latents_history_1x.shape[-2], + latents_history_1x.shape[-1], + ), + device=latents_history_1x.device, + dtype=latents_history_1x.dtype, + ) else: latents_prefix = image_latents latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index e716bf5569ae..7db36348d8e5 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -260,7 +260,6 @@ def encode_prompt( dtype=dtype, ) - negative_prompt_attention_mask = None if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt @@ -665,17 +664,15 @@ def __call__( interpolate_embeds = None interpolate_cumulative_list = list(accumulate(interpolate_time_list)) - all_prompt_embeds, negative_prompt_embeds = ( - self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - 
do_classifier_free_guidance=self.do_classifier_free_guidance, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - max_sequence_length=max_sequence_length, - device=device, - ) + all_prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, ) transformer_dtype = self.transformer.dtype @@ -860,8 +857,17 @@ def __call__( :, :, -sum(history_sizes) : ].split(history_sizes, dim=2) if image_latents is None and is_first_chunk: - latents_prefix = torch.zeros((batch_size, num_channels_latents, 1, latents_history_1x.shape[-2], latents_history_1x.shape[-1]), - device=latents_history_1x.device, dtype=latents_history_1x.dtype) + latents_prefix = torch.zeros( + ( + batch_size, + num_channels_latents, + 1, + latents_history_1x.shape[-2], + latents_history_1x.shape[-1], + ), + device=latents_history_1x.device, + dtype=latents_history_1x.dtype, + ) else: latents_prefix = image_latents latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) @@ -904,9 +910,9 @@ def __call__( ) * 2 ) - latents = latents.reshape(batch_size, num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width).permute( - 0, 2, 1, 3, 4 - ) + latents = latents.reshape( + batch_size, num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width + ).permute(0, 2, 1, 3, 4) start_point_list = None if use_dmd: @@ -940,11 +946,18 @@ def __call__( pyramid_width *= 2 num_frames = latents.shape[2] latents = latents.permute(0, 2, 1, 3, 4).reshape( - batch_size * num_latent_frames_per_chunk, num_channels_latents, pyramid_height // 2, pyramid_width // 2 + batch_size * 
num_latent_frames_per_chunk, + num_channels_latents, + pyramid_height // 2, + pyramid_width // 2, ) latents = F.interpolate(latents, size=(pyramid_height, pyramid_width), mode="nearest") latents = latents.reshape( - batch_size, num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width + batch_size, + num_latent_frames_per_chunk, + num_channels_latents, + pyramid_height, + pyramid_width, ).permute(0, 2, 1, 3, 4) # Fix the stage ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s] # the original coeff of signal From ca8cc7e9479c4d35e66dd2abe7c63f224ba94b85 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 04:57:31 +0000 Subject: [PATCH 057/107] update TYPE_CHECKING --- src/diffusers/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 450e69056a72..8b8d9c52659e 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -213,6 +213,7 @@ Flux2Transformer2DModel, FluxTransformer2DModel, GlmImageTransformer2DModel, + HeliosTransformer3DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, From 7174e44a4ceaf034e44b099477943edeaa2e7dc5 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Sun, 1 Mar 2026 13:06:01 +0800 Subject: [PATCH 058/107] change to use torch.split Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/models/transformers/transformer_helios.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 3c9eb1fbb5f7..c7d287115e72 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -462,10 +462,7 @@ def forward( if self.guidance_cross_attn: history_seq_len = hidden_states.shape[1] - 
original_context_length - history_hidden_states, hidden_states = ( - hidden_states[:, :history_seq_len], - hidden_states[:, history_seq_len:], - ) + history_hidden_states, hidden_states = torch.split(hidden_states, [history_seq_len, original_context_length], dim=1) norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) attn_output = self.attn2( norm_hidden_states, From dee65a540f6e6d278447a7be07cbca2ce944bf27 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 05:11:11 +0000 Subject: [PATCH 059/107] derive memory patch sizes from patch_size multiples --- .../models/transformers/transformer_helios.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index c7d287115e72..43b0327273b9 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -462,7 +462,9 @@ def forward( if self.guidance_cross_attn: history_seq_len = hidden_states.shape[1] - original_context_length - history_hidden_states, hidden_states = torch.split(hidden_states, [history_seq_len, original_context_length], dim=1) + history_hidden_states, hidden_states = torch.split( + hidden_states, [history_seq_len, original_context_length], dim=1 + ) norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) attn_output = self.attn2( norm_hidden_states, @@ -594,9 +596,19 @@ def __init__( self.zero_history_timestep = zero_history_timestep self.inner_dim = inner_dim if has_multi_term_memory_patch: - self.patch_short = nn.Conv3d(in_channels, self.inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) - self.patch_mid = nn.Conv3d(in_channels, self.inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) - self.patch_long = nn.Conv3d(in_channels, self.inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + self.patch_short = nn.Conv3d(in_channels, 
self.inner_dim, kernel_size=patch_size, stride=patch_size) + self.patch_mid = nn.Conv3d( + in_channels, + self.inner_dim, + kernel_size=tuple(2 * p for p in patch_size), + stride=tuple(2 * p for p in patch_size), + ) + self.patch_long = nn.Conv3d( + in_channels, + self.inner_dim, + kernel_size=tuple(4 * p for p in patch_size), + stride=tuple(4 * p for p in patch_size), + ) # 3. Condition embeddings self.condition_embedder = HeliosTimeTextEmbedding( From d431f54d6e1cf601c7495e000ec266753959be56 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 05:20:14 +0000 Subject: [PATCH 060/107] remove some hardcoding --- src/diffusers/pipelines/helios/pipeline_helios.py | 9 ++++++--- .../pipelines/helios/pipeline_helios_pyramid.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 2314d95c899b..ec4040587b6d 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -377,6 +377,7 @@ def prepare_image_latents( image: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, + num_latent_frames_per_chunk: int, dtype: torch.dtype | None = None, device: torch.device | None = None, generator: torch.Generator | list[torch.Generator] | None = None, @@ -389,7 +390,8 @@ def prepare_image_latents( latents = self.vae.encode(image).latent_dist.sample(generator=generator) latents = (latents - latents_mean) * latents_std if fake_latents is None: - fake_video = image.repeat(1, 1, 33, 1, 1).to(device=device, dtype=self.vae.dtype) + min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + fake_video = image.repeat(1, 1, min_frames, 1, 1).to(device=device, dtype=self.vae.dtype) fake_latents_full = self.vae.encode(fake_video).latent_dist.sample(generator=generator) fake_latents_full = (fake_latents_full - latents_mean) * latents_std fake_latents = 
fake_latents_full[:, :, -1:, :, :] @@ -410,13 +412,13 @@ def prepare_video_latents( video = video.to(device=device, dtype=self.vae.dtype) if latents is None: num_frames = video.shape[2] - min_frames = (num_latent_frames_per_chunk - 1) * 4 + 1 + min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 num_chunks = num_frames // min_frames if num_chunks == 0: raise ValueError( f"Video must have at least {min_frames} frames " f"(got {num_frames} frames). " - f"Required: (num_latent_frames_per_chunk - 1) * 4 + 1 = ({num_latent_frames_per_chunk} - 1) * 4 + 1 = {min_frames}" + f"Required: (num_latent_frames_per_chunk - 1) * {self.vae_scale_factor_temporal} + 1 = ({num_latent_frames_per_chunk} - 1) * {self.vae_scale_factor_temporal} + 1 = {min_frames}" ) total_valid_frames = num_chunks * min_frames start_frame = num_frames - total_valid_frames @@ -678,6 +680,7 @@ def __call__( image, latents_mean=latents_mean, latents_std=latents_std, + num_latent_frames_per_chunk=num_latent_frames_per_chunk, dtype=torch.float32, device=device, generator=generator, diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 7db36348d8e5..c2da3034e44e 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -378,6 +378,7 @@ def prepare_image_latents( image: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, + num_latent_frames_per_chunk: int, dtype: torch.dtype | None = None, device: torch.device | None = None, generator: torch.Generator | list[torch.Generator] | None = None, @@ -390,7 +391,8 @@ def prepare_image_latents( latents = self.vae.encode(image).latent_dist.sample(generator=generator) latents = (latents - latents_mean) * latents_std if fake_latents is None: - fake_video = image.repeat(1, 1, 33, 1, 1).to(device=device, dtype=self.vae.dtype) + min_frames = 
(num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + fake_video = image.repeat(1, 1, min_frames, 1, 1).to(device=device, dtype=self.vae.dtype) fake_latents_full = self.vae.encode(fake_video).latent_dist.sample(generator=generator) fake_latents_full = (fake_latents_full - latents_mean) * latents_std fake_latents = fake_latents_full[:, :, -1:, :, :] @@ -411,13 +413,13 @@ def prepare_video_latents( video = video.to(device=device, dtype=self.vae.dtype) if latents is None: num_frames = video.shape[2] - min_frames = (num_latent_frames_per_chunk - 1) * 4 + 1 + min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 num_chunks = num_frames // min_frames if num_chunks == 0: raise ValueError( f"Video must have at least {min_frames} frames " f"(got {num_frames} frames). " - f"Required: (num_latent_frames_per_chunk - 1) * 4 + 1 = ({num_latent_frames_per_chunk} - 1) * 4 + 1 = {min_frames}" + f"Required: (num_latent_frames_per_chunk - 1) * {self.vae_scale_factor_temporal} + 1 = ({num_latent_frames_per_chunk} - 1) * {self.vae_scale_factor_temporal} + 1 = {min_frames}" ) total_valid_frames = num_chunks * min_frames start_frame = num_frames - total_valid_frames @@ -689,6 +691,7 @@ def __call__( image, latents_mean=latents_mean, latents_std=latents_std, + num_latent_frames_per_chunk=num_latent_frames_per_chunk, dtype=torch.float32, device=device, generator=generator, From c618877bdc3a1b62796b72b0ffd417c65c87fae7 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 05:24:12 +0000 Subject: [PATCH 061/107] move some checks into check_inputs --- src/diffusers/pipelines/helios/pipeline_helios.py | 10 +++++++--- .../pipelines/helios/pipeline_helios_pyramid.py | 10 +++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index ec4040587b6d..8e3e481d8238 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py 
+++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -294,6 +294,8 @@ def check_inputs( prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, + image=None, + video=None, use_interpolate_prompt=False, num_videos_per_prompt=None, interpolate_time_list=None, @@ -330,6 +332,9 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if image is not None and video is not None: + raise ValueError("image and video cannot be provided simultaneously") + if use_interpolate_prompt: assert num_videos_per_prompt == 1, f"num_videos_per_prompt must be 1, got {num_videos_per_prompt}" assert isinstance(prompt, list), "prompt must be a list" @@ -599,9 +604,6 @@ def __call__( indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ - if image is not None and video is not None: - raise ValueError("image and video cannot be provided simultaneously") - history_sizes = sorted(history_sizes, reverse=True) # From big to small latents_mean = ( @@ -625,6 +627,8 @@ def __call__( prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, + image, + video, use_interpolate_prompt, num_videos_per_prompt, interpolate_time_list, diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index c2da3034e44e..b9b06ca6616a 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -295,6 +295,8 @@ def check_inputs( prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, + image=None, + video=None, use_interpolate_prompt=False, num_videos_per_prompt=None, interpolate_time_list=None, @@ -331,6 +333,9 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if image is not None 
and video is not None: + raise ValueError("image and video cannot be provided simultaneously") + if use_interpolate_prompt: assert num_videos_per_prompt == 1, f"num_videos_per_prompt must be 1, got {num_videos_per_prompt}" assert isinstance(prompt, list), "prompt must be a list" @@ -609,9 +614,6 @@ def __call__( indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ - if image is not None and video is not None: - raise ValueError("image and video cannot be provided simultaneously") - history_sizes = sorted(history_sizes, reverse=True) # From big to small use_dmd = True if self.scheduler.config.scheduler_type == "dmd" else False @@ -636,6 +638,8 @@ def __call__( prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, + image, + video, use_interpolate_prompt, num_videos_per_prompt, interpolate_time_list, From f25d9272cff506a7d8597d63bd91261a374a56ff Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 05:30:25 +0000 Subject: [PATCH 062/107] refactor sample_block_noise --- .../pipelines/helios/pipeline_helios.py | 11 ---------- .../helios/pipeline_helios_pyramid.py | 21 ++++++++++++------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 8e3e481d8238..a09a8a6bba9a 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -457,17 +457,6 @@ def interpolate_prompt_embeds( interpolated_prompt_embeds = list(x.chunk(interpolation_steps, dim=0)) return interpolated_prompt_embeds - def sample_block_noise(self, batch_size, channel, num_frames, height, width): - gamma = self.scheduler.config.gamma - cov = torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma - dist = torch.distributions.MultivariateNormal(torch.zeros(4, device=cov.device), covariance_matrix=cov) - block_number = batch_size * channel * num_frames * 
(height // 2) * (width // 2) - - noise = dist.sample((block_number,)) # [block number, 4] - noise = noise.view(batch_size, channel, num_frames, height // 2, width // 2, 2, 2) - noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width) - return noise - @property def guidance_scale(self): return self._guidance_scale diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index b9b06ca6616a..9abeaa64eb57 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -458,14 +458,19 @@ def interpolate_prompt_embeds( interpolated_prompt_embeds = list(x.chunk(interpolation_steps, dim=0)) return interpolated_prompt_embeds - def sample_block_noise(self, batch_size, channel, num_frames, height, width): + def sample_block_noise(self, batch_size, channel, num_frames, height, width, patch_size=(1, 2, 2)): gamma = self.scheduler.config.gamma - cov = torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma - dist = torch.distributions.MultivariateNormal(torch.zeros(4, device=cov.device), covariance_matrix=cov) - block_number = batch_size * channel * num_frames * (height // 2) * (width // 2) + _, ph, pw = patch_size + block_size = ph * pw - noise = dist.sample((block_number,)) # [block number, 4] - noise = noise.view(batch_size, channel, num_frames, height // 2, width // 2, 2, 2) + cov = torch.eye(block_size) * (1 + gamma) - torch.ones(block_size, block_size) * gamma + dist = torch.distributions.MultivariateNormal( + torch.zeros(block_size, device=cov.device), covariance_matrix=cov + ) + block_number = batch_size * channel * num_frames * (height // ph) * (width // pw) + + noise = dist.sample((block_number,)) # [block number, block_size] + noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw) noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, 
num_frames, height, width) return noise @@ -973,7 +978,9 @@ def __call__( beta = alpha * (1 - ori_sigma) / math.sqrt(gamma) batch_size, channel, num_frames, pyramid_height, pyramid_width = latents.shape - noise = self.sample_block_noise(batch_size, channel, num_frames, pyramid_height, pyramid_width) + noise = self.sample_block_noise( + batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size + ) noise = noise.to(device=device, dtype=transformer_dtype) latents = alpha * latents + beta * noise # To fix the block artifact From 42d30d0051ec08e8121a2c34dc77ec874cc82470 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 05:42:53 +0000 Subject: [PATCH 063/107] optimize encoding chunks logits for v2v --- src/diffusers/pipelines/helios/pipeline_helios.py | 4 ++-- src/diffusers/pipelines/helios/pipeline_helios_pyramid.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index a09a8a6bba9a..33a650bb8d2f 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -433,13 +433,13 @@ def prepare_video_latents( first_frame_latent = (first_frame_latent - latents_mean) * latents_std latents_chunks = [] - for i in range(num_chunks - 1, -1, -1): + for i in range(num_chunks): chunk_start = start_frame + i * min_frames chunk_end = chunk_start + min_frames video_chunk = video[:, :, chunk_start:chunk_end, :, :] chunk_latents = self.vae.encode(video_chunk).latent_dist.sample(generator=generator) chunk_latents = (chunk_latents - latents_mean) * latents_std - latents_chunks.insert(0, chunk_latents) + latents_chunks.append(chunk_latents) latents = torch.cat(latents_chunks, dim=2) return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py 
b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 9abeaa64eb57..33b137435d0d 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -434,13 +434,13 @@ def prepare_video_latents( first_frame_latent = (first_frame_latent - latents_mean) * latents_std latents_chunks = [] - for i in range(num_chunks - 1, -1, -1): + for i in range(num_chunks): chunk_start = start_frame + i * min_frames chunk_end = chunk_start + min_frames video_chunk = video[:, :, chunk_start:chunk_end, :, :] chunk_latents = self.vae.encode(video_chunk).latent_dist.sample(generator=generator) chunk_latents = (chunk_latents - latents_mean) * latents_std - latents_chunks.insert(0, chunk_latents) + latents_chunks.append(chunk_latents) latents = torch.cat(latents_chunks, dim=2) return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype) From 04a0342ab21cd37e0d3e20c6dff33da0b364d0b0 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 05:44:39 +0000 Subject: [PATCH 064/107] use num_history_latent_frames = sum(history_sizes) --- src/diffusers/pipelines/helios/pipeline_helios.py | 7 ++++--- src/diffusers/pipelines/helios/pipeline_helios_pyramid.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 33a650bb8d2f..9189d917bb6e 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -748,6 +748,7 @@ def __call__( num_channels_latents = self.transformer.config.in_channels window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 num_latent_chunk = max(1, (num_frames + window_num_frames - 1) // window_num_frames) + num_history_latent_frames = sum(history_sizes) history_video = None total_generated_latent_frames = 0 @@ -756,7 +757,7 @@ def 
__call__( history_latents = torch.zeros( batch_size, num_channels_latents, - sum(history_sizes), + num_history_latent_frames, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, device=device, @@ -855,7 +856,7 @@ def __call__( is_second_chunk = k == 1 if keep_first_frame: latents_history_long, latents_history_mid, latents_history_1x = history_latents[ - :, :, -sum(history_sizes) : + :, :, -num_history_latent_frames: ].split(history_sizes, dim=2) if image_latents is None and is_first_chunk: latents_prefix = torch.zeros( @@ -874,7 +875,7 @@ def __call__( latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) else: latents_history_long, latents_history_mid, latents_history_short = history_latents[ - :, :, -sum(history_sizes) : + :, :, -num_history_latent_frames: ].split(history_sizes, dim=2) latents = self.prepare_latents( diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 33b137435d0d..a4c54c78d844 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -775,6 +775,7 @@ def __call__( num_channels_latents = self.transformer.config.in_channels window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 num_latent_chunk = max(1, (num_frames + window_num_frames - 1) // window_num_frames) + num_history_latent_frames = sum(history_sizes) history_video = None total_generated_latent_frames = 0 @@ -783,7 +784,7 @@ def __call__( history_latents = torch.zeros( batch_size, num_channels_latents, - sum(history_sizes), + num_history_latent_frames, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, device=device, @@ -866,7 +867,7 @@ def __call__( is_second_chunk = k == 1 if keep_first_frame: latents_history_long, latents_history_mid, latents_history_1x = history_latents[ - :, :, -sum(history_sizes) : + :, 
:, -num_history_latent_frames: ].split(history_sizes, dim=2) if image_latents is None and is_first_chunk: latents_prefix = torch.zeros( @@ -885,7 +886,7 @@ def __call__( latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) else: latents_history_long, latents_history_mid, latents_history_short = history_latents[ - :, :, -sum(history_sizes) : + :, :, -num_history_latent_frames: ].split(history_sizes, dim=2) latents = self.prepare_latents( From 971f08c333479d203e3eccafc0933bf7f25ac57f Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Sun, 1 Mar 2026 13:46:47 +0800 Subject: [PATCH 065/107] Update src/diffusers/pipelines/helios/pipeline_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/pipelines/helios/pipeline_helios.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 9189d917bb6e..1f3fe90194d9 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -970,13 +970,7 @@ def __call__( total_generated_latent_frames += latents.shape[2] history_latents = torch.cat([history_latents, latents], dim=2) real_history_latents = history_latents[:, :, -total_generated_latent_frames:] - index_slice = ( - slice(None), - slice(None), - slice(-num_latent_frames_per_chunk, None), - ) - - current_latents = real_history_latents[index_slice].to(vae_dtype) / latents_std + latents_mean + current_latents = real_history_latents[:, :, -num_latent_frames_per_chunk:].to(vae_dtype) / latents_std + latents_mean current_video = self.vae.decode(current_latents, return_dict=False)[0] if history_video is None: From f0f99c1b6b15f813fd3a7d9a1594157c2e62b34c Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 05:48:22 +0000 Subject: [PATCH 066/107] remove redudant optimized_scale --- 
src/diffusers/pipelines/helios/pipeline_helios.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 1f3fe90194d9..a346f7e2796d 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -75,16 +75,6 @@ """ -def optimized_scale(positive_flat, negative_flat): - # Calculate dot production - dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) - # Squared norm of uncondition - squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 - # st_star = v_cond^T * v_uncond / ||v_uncond||^2 - st_star = dot_product / squared_norm - return st_star - - def basic_clean(text): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) @@ -970,7 +960,10 @@ def __call__( total_generated_latent_frames += latents.shape[2] history_latents = torch.cat([history_latents, latents], dim=2) real_history_latents = history_latents[:, :, -total_generated_latent_frames:] - current_latents = real_history_latents[:, :, -num_latent_frames_per_chunk:].to(vae_dtype) / latents_std + latents_mean + current_latents = ( + real_history_latents[:, :, -num_latent_frames_per_chunk:].to(vae_dtype) / latents_std + + latents_mean + ) current_video = self.vae.decode(current_latents, return_dict=False)[0] if history_video is None: From 4e418a425f0b523c38c16eaa3632385f89c0224b Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Sun, 1 Mar 2026 13:49:28 +0800 Subject: [PATCH 067/107] Update src/diffusers/pipelines/helios/pipeline_helios_pyramid.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/pipelines/helios/pipeline_helios_pyramid.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py 
b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index a4c54c78d844..f00ec2b755d2 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -164,6 +164,7 @@ def __init__( self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + # Copied from diffusers.pipelines.helios.pipeline_helios.HeliosPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: str | list[str] = None, From 70dbf19e075aae247c3628150fa3f86590bdc62e Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 05:51:25 +0000 Subject: [PATCH 068/107] use more descriptive name --- .../pipelines/helios/pipeline_helios_pyramid.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index f00ec2b755d2..b2ebfde0d20b 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -932,7 +932,7 @@ def __call__( if use_dmd: start_point_list = [latents] - for i_s in range(pyramid_num_stages): + for stage_idx in range(pyramid_num_stages): patch_size = self.transformer.config.patch_size image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( patch_size[0] * patch_size[1] * patch_size[2] @@ -945,8 +945,8 @@ def __call__( self.scheduler.config.get("max_shift", 1.15), ) self.scheduler.set_timesteps( - pyramid_num_inference_steps_list[i_s], - i_s, + pyramid_num_inference_steps_list[stage_idx], + stage_idx, device=device, mu=mu, is_amplify_first_chunk=is_amplify_first_chunk and is_first_chunk, @@ -955,7 +955,7 @@ def __call__( num_warmup_steps = 0 self._num_timesteps = len(timesteps) - if i_s > 0: + if stage_idx > 0: pyramid_height *= 2 
pyramid_width *= 2 num_frames = latents.shape[2] @@ -974,7 +974,7 @@ def __call__( pyramid_width, ).permute(0, 2, 1, 3, 4) # Fix the stage - ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s] # the original coeff of signal + ori_sigma = 1 - self.scheduler.ori_start_sigmas[stage_idx] # the original coeff of signal gamma = self.scheduler.config.gamma alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma) beta = alpha * (1 - ori_sigma) / math.sqrt(gamma) @@ -1038,7 +1038,7 @@ def __call__( alpha = alpha.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1))) alpha = alpha.to(noise_pred_text.dtype) - if (i_s == 0 and i <= zero_steps) and use_zero_init: + if (stage_idx == 0 and i <= zero_steps) and use_zero_init: noise_pred = noise_pred_text * 0.0 else: noise_pred = noise_uncond * alpha + guidance_scale * ( @@ -1054,7 +1054,7 @@ def __call__( generator=generator, return_dict=False, cur_sampling_step=i, - dmd_noisy_tensor=start_point_list[i_s] if start_point_list is not None else None, + dmd_noisy_tensor=start_point_list[stage_idx] if start_point_list is not None else None, dmd_sigmas=self.scheduler.sigmas, dmd_timesteps=self.scheduler.timesteps, all_timesteps=timesteps, From ae7817f1aaa3b56851a11c2b0448676d7bf8c38f Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 06:14:28 +0000 Subject: [PATCH 069/107] optimize history_latents --- src/diffusers/pipelines/helios/pipeline_helios.py | 2 +- src/diffusers/pipelines/helios/pipeline_helios_pyramid.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index a346f7e2796d..c11a11b50852 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -754,7 +754,7 @@ def __call__( dtype=torch.float32, ) if fake_image_latents is not None: - history_latents = torch.cat([history_latents, fake_image_latents], dim=2) + 
history_latents = torch.cat([history_latents[:, :, :-1, :, :], fake_image_latents], dim=2) total_generated_latent_frames += 1 if video_latents is not None: history_frames = history_latents.shape[2] diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index b2ebfde0d20b..9121960cdd88 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -792,7 +792,7 @@ def __call__( dtype=torch.float32, ) if fake_image_latents is not None: - history_latents = torch.cat([history_latents, fake_image_latents], dim=2) + history_latents = torch.cat([history_latents[:, :, :-1, :, :], fake_image_latents], dim=2) total_generated_latent_frames += 1 if video_latents is not None: history_frames = history_latents.shape[2] From dba5aac2b9d98624ee1f849e18a7c1baff2d39c0 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 10:50:16 +0000 Subject: [PATCH 070/107] remove not used "num_inference_steps" --- src/diffusers/pipelines/helios/pipeline_helios_pyramid.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 9121960cdd88..7967e7281c60 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -508,7 +508,6 @@ def __call__( height: int = 384, width: int = 640, num_frames: int = 132, - num_inference_steps: int = 50, sigmas: list[float] = None, guidance_scale: float = 5.0, num_videos_per_prompt: int | None = 1, From 6e4d099818d82f17d7d4b1113edc82ce7977c114 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 10:53:33 +0000 Subject: [PATCH 071/107] removed redudant "pyramid_num_stages" --- 0_temp_helios_test/infer_helios.py | 2 -- src/diffusers/pipelines/helios/pipeline_helios_pyramid.py | 2 +- 2 files changed, 1 
insertion(+), 3 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index f30eda5e928a..713dd6673547 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -72,7 +72,6 @@ def parse_args(): parser.add_argument("--num_latent_frames_per_chunk", type=int, default=9) # stage 2 parser.add_argument("--is_enable_stage2", action="store_true") - parser.add_argument("--pyramid_num_stages", type=int, default=3) parser.add_argument("--stage2_timestep_shift", type=float, default=1.0) parser.add_argument("--stage2_scheduler_gamma", type=float, default=1 / 3) parser.add_argument("--stage2_stage_range", type=int, nargs="+", default=[0, 1 / 3, 2 / 3, 1]) @@ -288,7 +287,6 @@ def main(): keep_first_frame=True, is_skip_first_chunk=args.is_skip_first_chunk, # stage 2 - pyramid_num_stages=args.pyramid_num_stages, pyramid_num_inference_steps_list=args.pyramid_num_inference_steps_list, # stage 3 is_amplify_first_chunk=args.is_amplify_first_chunk, diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 7967e7281c60..fd6f16a08ad2 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -544,7 +544,6 @@ def __call__( keep_first_frame: bool = True, is_skip_first_chunk: bool = False, # ------------ Stage 2 ------------ - pyramid_num_stages: int = 3, pyramid_num_inference_steps_list: list = [10, 10, 10], # ------------ CFG Zero ------------ use_cfg_zero_star: bool | None = False, @@ -621,6 +620,7 @@ def __call__( history_sizes = sorted(history_sizes, reverse=True) # From big to small use_dmd = True if self.scheduler.config.scheduler_type == "dmd" else False + pyramid_num_stages = len(pyramid_num_inference_steps_list) latents_mean = ( torch.tensor(self.vae.config.latents_mean) From 63983a7959ecd1364548bb596fb8c34cfaa156c6 Mon Sep 17 00:00:00 2001 From: 
SHYuanBest Date: Sun, 1 Mar 2026 11:13:42 +0000 Subject: [PATCH 072/107] add "is_cfg_zero_star" and "is_distilled" to HeliosPyramidPipeline --- 0_temp_helios_test/infer_helios.py | 3 --- 0_temp_helios_test/stage-2_i2v.sh | 1 - 0_temp_helios_test/stage-2_t2v.sh | 1 - 0_temp_helios_test/stage-2_v2v.sh | 1 - .../helios/pipeline_helios_pyramid.py | 22 +++++++++++-------- 5 files changed, 13 insertions(+), 15 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index 713dd6673547..bb772fd72fed 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -65,7 +65,6 @@ def parse_args(): parser.add_argument("--num_inference_steps", type=int, default=50) parser.add_argument("--guidance_scale", type=float, default=5.0) # cfg zero - parser.add_argument("--use_cfg_zero_star", action="store_true") parser.add_argument("--use_zero_init", action="store_true") parser.add_argument("--zero_steps", type=int, default=1) # stage 1 @@ -278,7 +277,6 @@ def main(): height=args.height, width=args.width, num_frames=args.num_frames, # 73 109 145 181 215 - num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, generator=torch.Generator(device="cuda").manual_seed(args.seed), # stage 1 @@ -291,7 +289,6 @@ def main(): # stage 3 is_amplify_first_chunk=args.is_amplify_first_chunk, # cfg zero - use_cfg_zero_star=args.use_cfg_zero_star, use_zero_init=args.use_zero_init, zero_steps=args.zero_steps, # i2v diff --git a/0_temp_helios_test/stage-2_i2v.sh b/0_temp_helios_test/stage-2_i2v.sh index 911d92d747f0..53033bd38bee 100644 --- a/0_temp_helios_test/stage-2_i2v.sh +++ b/0_temp_helios_test/stage-2_i2v.sh @@ -7,7 +7,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --guidance_scale 5.0 \ --is_enable_stage2 \ --pyramid_num_inference_steps_list 20 20 20 \ - --use_cfg_zero_star \ --use_zero_init \ --zero_steps 1 \ --output_folder "./output_helios/stage-2" diff --git a/0_temp_helios_test/stage-2_t2v.sh 
b/0_temp_helios_test/stage-2_t2v.sh index eacf82d69cf1..b0ec7201a318 100644 --- a/0_temp_helios_test/stage-2_t2v.sh +++ b/0_temp_helios_test/stage-2_t2v.sh @@ -6,7 +6,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --guidance_scale 5.0 \ --is_enable_stage2 \ --pyramid_num_inference_steps_list 20 20 20 \ - --use_cfg_zero_star \ --use_zero_init \ --zero_steps 1 \ --output_folder "./output_helios/stage-2" diff --git a/0_temp_helios_test/stage-2_v2v.sh b/0_temp_helios_test/stage-2_v2v.sh index b0b41ddb89ab..a9b51592087f 100644 --- a/0_temp_helios_test/stage-2_v2v.sh +++ b/0_temp_helios_test/stage-2_v2v.sh @@ -7,7 +7,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --guidance_scale 5.0 \ --is_enable_stage2 \ --pyramid_num_inference_steps_list 20 20 20 \ - --use_cfg_zero_star \ --use_zero_init \ --zero_steps 1 \ --output_folder "./output_helios/stage-2" diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index fd6f16a08ad2..d3a816d8cfa1 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -150,6 +150,8 @@ def __init__( vae: AutoencoderKLWan, scheduler: HeliosScheduler, transformer: HeliosTransformer3DModel, + is_cfg_zero_star: bool = False, + is_distilled: bool = False, ): super().__init__() @@ -160,6 +162,8 @@ def __init__( transformer=transformer, scheduler=scheduler, ) + self.register_to_config(is_cfg_zero_star=is_cfg_zero_star) + self.register_to_config(is_distilled=is_distilled) self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -302,6 +306,7 @@ def check_inputs( num_videos_per_prompt=None, interpolate_time_list=None, interpolation_steps=None, + 
guidance_scale=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -347,6 +352,9 @@ def check_inputs( f"Minimum value {min(interpolate_time_list)} must be greater than {interpolation_steps}" ) + if guidance_scale > 1.0 and self.config.is_distilled: + logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") + def prepare_latents( self, batch_size: int, @@ -546,7 +554,6 @@ def __call__( # ------------ Stage 2 ------------ pyramid_num_inference_steps_list: list = [10, 10, 10], # ------------ CFG Zero ------------ - use_cfg_zero_star: bool | None = False, use_zero_init: bool | None = True, zero_steps: int | None = 1, # ------------ DMD ------------ @@ -567,9 +574,6 @@ def __call__( The width in pixels of the generated image. num_frames (`int`, defaults to `132`): The number of frames in the generated video. - num_inference_steps (`int`, defaults to `50`): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. guidance_scale (`float`, defaults to `5.0`): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. 
@@ -619,7 +623,6 @@ def __call__( """ history_sizes = sorted(history_sizes, reverse=True) # From big to small - use_dmd = True if self.scheduler.config.scheduler_type == "dmd" else False pyramid_num_stages = len(pyramid_num_inference_steps_list) latents_mean = ( @@ -649,6 +652,7 @@ def __call__( num_videos_per_prompt, interpolate_time_list, interpolation_steps, + guidance_scale, ) num_frames = max(num_frames, 1) @@ -903,7 +907,7 @@ def __call__( num_inference_steps = ( sum(pyramid_num_inference_steps_list) * 2 - if is_amplify_first_chunk and use_dmd and is_first_chunk + if is_amplify_first_chunk and self.config.is_distilled and is_first_chunk else sum(pyramid_num_inference_steps_list) ) @@ -928,7 +932,7 @@ def __call__( ).permute(0, 2, 1, 3, 4) start_point_list = None - if use_dmd: + if self.config.is_distilled: start_point_list = [latents] for stage_idx in range(pyramid_num_stages): @@ -985,7 +989,7 @@ def __call__( noise = noise.to(device=device, dtype=transformer_dtype) latents = alpha * latents + beta * noise # To fix the block artifact - if use_dmd: + if self.config.is_distilled: start_point_list.append(latents) for i, t in enumerate(timesteps): @@ -1028,7 +1032,7 @@ def __call__( return_dict=False, )[0] - if use_cfg_zero_star: + if self.config.is_cfg_zero_star: noise_pred_text = noise_pred positive_flat = noise_pred_text.view(batch_size, -1) negative_flat = noise_uncond.view(batch_size, -1) From 72b8811ba2237da4d52b0f6d0bef7df109fa4987 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 11:40:54 +0000 Subject: [PATCH 073/107] remove redundant --- .../pipelines/helios/pipeline_helios_pyramid.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index d3a816d8cfa1..048cd220d1c7 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py
@@ -1091,13 +1091,10 @@ def __call__( total_generated_latent_frames += latents.shape[2] history_latents = torch.cat([history_latents, latents], dim=2) real_history_latents = history_latents[:, :, -total_generated_latent_frames:] - index_slice = ( - slice(None), - slice(None), - slice(-num_latent_frames_per_chunk, None), + current_latents = ( + real_history_latents[:, :, -num_latent_frames_per_chunk:].to(vae_dtype) / latents_std + + latents_mean ) - - current_latents = real_history_latents[index_slice].to(vae_dtype) / latents_std + latents_mean current_video = self.vae.decode(current_latents, return_dict=False)[0] if history_video is None: From 5a64c6d64d5d7c868f05d08d5cb3910e24e29ea2 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 11:57:15 +0000 Subject: [PATCH 074/107] change example scripts name --- .../{stage-1_i2v.sh => helios-base_i2v.sh} | 1 - .../{stage-1_t2v.sh => helios-base_t2v.sh} | 1 - .../{stage-1_v2v.sh => helios-base_v2v.sh} | 1 - .../{stage-3_i2v.sh => helios-distilled_i2v.sh} | 0 .../{stage-3_t2v.sh => helios-distilled_t2v.sh} | 0 .../{stage-3_v2v.sh => helios-distilled_v2v.sh} | 0 .../{stage-2_i2v.sh => helios-mid_i2v.sh} | 1 - .../{stage-2_t2v.sh => helios-mid_t2v.sh} | 1 - 0_temp_helios_test/stage-2_v2v.sh | 17 ----------------- 9 files changed, 22 deletions(-) rename 0_temp_helios_test/{stage-1_i2v.sh => helios-base_i2v.sh} (97%) rename 0_temp_helios_test/{stage-1_t2v.sh => helios-base_t2v.sh} (97%) rename 0_temp_helios_test/{stage-1_v2v.sh => helios-base_v2v.sh} (97%) rename 0_temp_helios_test/{stage-3_i2v.sh => helios-distilled_i2v.sh} (100%) rename 0_temp_helios_test/{stage-3_t2v.sh => helios-distilled_t2v.sh} (100%) rename 0_temp_helios_test/{stage-3_v2v.sh => helios-distilled_v2v.sh} (100%) rename 0_temp_helios_test/{stage-2_i2v.sh => helios-mid_i2v.sh} (97%) rename 0_temp_helios_test/{stage-2_t2v.sh => helios-mid_t2v.sh} (97%) delete mode 100644 0_temp_helios_test/stage-2_v2v.sh diff --git 
a/0_temp_helios_test/stage-1_i2v.sh b/0_temp_helios_test/helios-base_i2v.sh similarity index 97% rename from 0_temp_helios_test/stage-1_i2v.sh rename to 0_temp_helios_test/helios-base_i2v.sh index 664892db3c14..db650c0e577a 100644 --- a/0_temp_helios_test/stage-1_i2v.sh +++ b/0_temp_helios_test/helios-base_i2v.sh @@ -8,7 +8,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --output_folder "./output_helios/stage-1" - # --use_default_loader \ # --enable_compile \ # --use_cfg_zero_star \ # --use_zero_init \ diff --git a/0_temp_helios_test/stage-1_t2v.sh b/0_temp_helios_test/helios-base_t2v.sh similarity index 97% rename from 0_temp_helios_test/stage-1_t2v.sh rename to 0_temp_helios_test/helios-base_t2v.sh index 9932ccdbbee1..63d0ea7dd552 100644 --- a/0_temp_helios_test/stage-1_t2v.sh +++ b/0_temp_helios_test/helios-base_t2v.sh @@ -7,7 +7,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --output_folder "./output_helios/stage-1" - # --use_default_loader \ # --enable_compile \ # --use_cfg_zero_star \ # --use_zero_init \ diff --git a/0_temp_helios_test/stage-1_v2v.sh b/0_temp_helios_test/helios-base_v2v.sh similarity index 97% rename from 0_temp_helios_test/stage-1_v2v.sh rename to 0_temp_helios_test/helios-base_v2v.sh index 626403936e69..cbdd21bcb250 100644 --- a/0_temp_helios_test/stage-1_v2v.sh +++ b/0_temp_helios_test/helios-base_v2v.sh @@ -8,7 +8,6 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --output_folder "./output_helios/stage-1" - # --use_default_loader \ # --enable_compile \ # --use_cfg_zero_star \ # --use_zero_init \ diff --git a/0_temp_helios_test/stage-3_i2v.sh b/0_temp_helios_test/helios-distilled_i2v.sh similarity index 100% rename from 0_temp_helios_test/stage-3_i2v.sh rename to 0_temp_helios_test/helios-distilled_i2v.sh diff --git a/0_temp_helios_test/stage-3_t2v.sh b/0_temp_helios_test/helios-distilled_t2v.sh similarity index 100% rename from 0_temp_helios_test/stage-3_t2v.sh rename to 0_temp_helios_test/helios-distilled_t2v.sh diff --git 
a/0_temp_helios_test/stage-3_v2v.sh b/0_temp_helios_test/helios-distilled_v2v.sh similarity index 100% rename from 0_temp_helios_test/stage-3_v2v.sh rename to 0_temp_helios_test/helios-distilled_v2v.sh diff --git a/0_temp_helios_test/stage-2_i2v.sh b/0_temp_helios_test/helios-mid_i2v.sh similarity index 97% rename from 0_temp_helios_test/stage-2_i2v.sh rename to 0_temp_helios_test/helios-mid_i2v.sh index 53033bd38bee..79154db937fe 100644 --- a/0_temp_helios_test/stage-2_i2v.sh +++ b/0_temp_helios_test/helios-mid_i2v.sh @@ -13,5 +13,4 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ # --pyramid_num_inference_steps_list 17 17 17 \ - # --use_default_loader \ # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-2_t2v.sh b/0_temp_helios_test/helios-mid_t2v.sh similarity index 97% rename from 0_temp_helios_test/stage-2_t2v.sh rename to 0_temp_helios_test/helios-mid_t2v.sh index b0ec7201a318..7d3eedff7dd2 100644 --- a/0_temp_helios_test/stage-2_t2v.sh +++ b/0_temp_helios_test/helios-mid_t2v.sh @@ -12,5 +12,4 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ # --pyramid_num_inference_steps_list 17 17 17 \ - # --use_default_loader \ # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/stage-2_v2v.sh b/0_temp_helios_test/stage-2_v2v.sh deleted file mode 100644 index a9b51592087f..000000000000 --- a/0_temp_helios_test/stage-2_v2v.sh +++ /dev/null @@ -1,17 +0,0 @@ -CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "BestWishYsh/Helios-Mid" \ - --transformer_path "BestWishYsh/Helios-Mid" \ - --sample_type "v2v" \ - --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ - --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. 
The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." \ - --guidance_scale 5.0 \ - --is_enable_stage2 \ - --pyramid_num_inference_steps_list 20 20 20 \ - --use_zero_init \ - --zero_steps 1 \ - --output_folder "./output_helios/stage-2" - - - # --pyramid_num_inference_steps_list 17 17 17 \ - # --use_default_loader \ - # --enable_compile \ \ No newline at end of file From 0b6a41e9b5ce47b05ed0d55619339d502ccf9160 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 11:57:30 +0000 Subject: [PATCH 075/107] change example scripts name --- 0_temp_helios_test/helios-mid_v2v.sh | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 0_temp_helios_test/helios-mid_v2v.sh diff --git a/0_temp_helios_test/helios-mid_v2v.sh b/0_temp_helios_test/helios-mid_v2v.sh new file mode 100644 index 000000000000..28597c92c991 --- /dev/null +++ b/0_temp_helios_test/helios-mid_v2v.sh @@ -0,0 +1,16 @@ +CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ + --base_model_path "BestWishYsh/Helios-Mid" \ + --transformer_path "BestWishYsh/Helios-Mid" \ + --sample_type "v2v" \ + --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ + --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. 
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." \ + --guidance_scale 5.0 \ + --is_enable_stage2 \ + --pyramid_num_inference_steps_list 20 20 20 \ + --use_zero_init \ + --zero_steps 1 \ + --output_folder "./output_helios/stage-2" + + + # --pyramid_num_inference_steps_list 17 17 17 \ + # --enable_compile \ \ No newline at end of file From 11e1a9ad88a1a80fb51a69f674d8229836e978e1 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 12:36:41 +0000 Subject: [PATCH 076/107] correct docs --- docs/source/en/api/pipelines/helios.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index 952dd77a6cf5..c6c30d8a0d36 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -28,7 +28,7 @@ The following Helios models are supported in Diffusers: -- [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): Best Quality, with v-prediction, standard CFG and UniPCMultistepScheduler. +- [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): Best Quality, with v-prediction, standard CFG and custom HeliosScheduler. - [Helios-Mid](https://huggingface.co/BestWishYsh/Helios-Mid): Intermediate Weight, with v-prediction, CFG-Zero* and custom HeliosScheduler. - [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): Best Efficiency, with x0-prediction and custom HeliosScheduler. 
From a88f60fe4f2f1a8946214ff17447655b855ae8ba Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 14:48:05 +0000 Subject: [PATCH 077/107] update example --- 0_temp_helios_test/helios-base_i2v.sh | 2 +- 0_temp_helios_test/helios-base_t2v.sh | 4 ++-- 0_temp_helios_test/helios-base_v2v.sh | 2 +- 0_temp_helios_test/helios-distilled_i2v.sh | 2 +- 0_temp_helios_test/helios-distilled_t2v.sh | 4 ++-- 0_temp_helios_test/helios-distilled_v2v.sh | 2 +- 0_temp_helios_test/helios-mid_i2v.sh | 2 +- 0_temp_helios_test/helios-mid_t2v.sh | 4 ++-- 0_temp_helios_test/helios-mid_v2v.sh | 2 +- 9 files changed, 12 insertions(+), 12 deletions(-) diff --git a/0_temp_helios_test/helios-base_i2v.sh b/0_temp_helios_test/helios-base_i2v.sh index db650c0e577a..44fcb203350c 100644 --- a/0_temp_helios_test/helios-base_i2v.sh +++ b/0_temp_helios_test/helios-base_i2v.sh @@ -5,7 +5,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" \ --guidance_scale 5.0 \ --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." 
\ - --output_folder "./output_helios/stage-1" + --output_folder "./output_helios/helios-base" # --enable_compile \ diff --git a/0_temp_helios_test/helios-base_t2v.sh b/0_temp_helios_test/helios-base_t2v.sh index 63d0ea7dd552..8626e439bfca 100644 --- a/0_temp_helios_test/helios-base_t2v.sh +++ b/0_temp_helios_test/helios-base_t2v.sh @@ -2,9 +2,9 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --base_model_path "BestWishYsh/Helios-Base" \ --transformer_path "BestWishYsh/Helios-Base" \ --sample_type "t2v" \ - --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." \ + --prompt "A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and the vivid colors of its surroundings. A close-up shot with dynamic movement." 
\ --guidance_scale 5.0 \ - --output_folder "./output_helios/stage-1" + --output_folder "./output_helios/helios-base" # --enable_compile \ diff --git a/0_temp_helios_test/helios-base_v2v.sh b/0_temp_helios_test/helios-base_v2v.sh index cbdd21bcb250..710a1b45eb74 100644 --- a/0_temp_helios_test/helios-base_v2v.sh +++ b/0_temp_helios_test/helios-base_v2v.sh @@ -5,7 +5,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ --guidance_scale 5.0 \ - --output_folder "./output_helios/stage-1" + --output_folder "./output_helios/helios-base" # --enable_compile \ diff --git a/0_temp_helios_test/helios-distilled_i2v.sh b/0_temp_helios_test/helios-distilled_i2v.sh index e903a1475e2f..00c147a57f5a 100644 --- a/0_temp_helios_test/helios-distilled_i2v.sh +++ b/0_temp_helios_test/helios-distilled_i2v.sh @@ -10,7 +10,7 @@ CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ --pyramid_num_inference_steps_list 2 2 2 \ --is_enable_stage3 \ --is_amplify_first_chunk \ - --output_folder "./output_helios/stage-3" + --output_folder "./output_helios/helios-distilled" # --pyramid_num_inference_steps_list 1 1 1 \ diff --git a/0_temp_helios_test/helios-distilled_t2v.sh b/0_temp_helios_test/helios-distilled_t2v.sh index 02d8ca33dbf8..bbe95b9981eb 100644 --- a/0_temp_helios_test/helios-distilled_t2v.sh +++ b/0_temp_helios_test/helios-distilled_t2v.sh @@ -2,14 +2,14 @@ CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ --base_model_path "BestWishYsh/Helios-Distilled" \ --transformer_path "BestWishYsh/Helios-Distilled" \ --sample_type "t2v" \ - --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." \ + --prompt "A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. 
The coral reefs are alive with a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and the vivid colors of its surroundings. A close-up shot with dynamic movement." \ --num_frames 240 \ --guidance_scale 1.0 \ --is_enable_stage2 \ --pyramid_num_inference_steps_list 2 2 2 \ --is_enable_stage3 \ --is_amplify_first_chunk \ - --output_folder "./output_helios/stage-3" + --output_folder "./output_helios/helios-distilled" # --pyramid_num_inference_steps_list 1 1 1 \ diff --git a/0_temp_helios_test/helios-distilled_v2v.sh b/0_temp_helios_test/helios-distilled_v2v.sh index 24fbacf64312..f6ecadcf0c82 100644 --- a/0_temp_helios_test/helios-distilled_v2v.sh +++ b/0_temp_helios_test/helios-distilled_v2v.sh @@ -10,7 +10,7 @@ CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ --pyramid_num_inference_steps_list 2 2 2 \ --is_enable_stage3 \ --is_amplify_first_chunk \ - --output_folder "./output_helios/stage-3" + --output_folder "./output_helios/helios-distilled" # --pyramid_num_inference_steps_list 1 1 1 \ diff --git a/0_temp_helios_test/helios-mid_i2v.sh b/0_temp_helios_test/helios-mid_i2v.sh index 79154db937fe..46de67456e8d 100644 --- a/0_temp_helios_test/helios-mid_i2v.sh +++ b/0_temp_helios_test/helios-mid_i2v.sh @@ -9,7 +9,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --pyramid_num_inference_steps_list 20 20 20 \ --use_zero_init \ --zero_steps 1 \ - --output_folder "./output_helios/stage-2" + --output_folder "./output_helios/helios-mid" # --pyramid_num_inference_steps_list 17 17 17 \ diff --git a/0_temp_helios_test/helios-mid_t2v.sh b/0_temp_helios_test/helios-mid_t2v.sh index 7d3eedff7dd2..c56394776550 100644 --- a/0_temp_helios_test/helios-mid_t2v.sh +++ 
b/0_temp_helios_test/helios-mid_t2v.sh @@ -2,13 +2,13 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --base_model_path "BestWishYsh/Helios-Mid" \ --transformer_path "BestWishYsh/Helios-Mid" \ --sample_type "t2v" \ - --prompt "A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond." \ + --prompt "A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and the vivid colors of its surroundings. A close-up shot with dynamic movement." 
\ --guidance_scale 5.0 \ --is_enable_stage2 \ --pyramid_num_inference_steps_list 20 20 20 \ --use_zero_init \ --zero_steps 1 \ - --output_folder "./output_helios/stage-2" + --output_folder "./output_helios/helios-mid" # --pyramid_num_inference_steps_list 17 17 17 \ diff --git a/0_temp_helios_test/helios-mid_v2v.sh b/0_temp_helios_test/helios-mid_v2v.sh index 28597c92c991..c4eadade4370 100644 --- a/0_temp_helios_test/helios-mid_v2v.sh +++ b/0_temp_helios_test/helios-mid_v2v.sh @@ -9,7 +9,7 @@ CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ --pyramid_num_inference_steps_list 20 20 20 \ --use_zero_init \ --zero_steps 1 \ - --output_folder "./output_helios/stage-2" + --output_folder "./output_helios/helios-mid" # --pyramid_num_inference_steps_list 17 17 17 \ From c4fc87b1f396e1b3644e14b9c58062ebfbffc127 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Sun, 1 Mar 2026 14:50:30 +0000 Subject: [PATCH 078/107] update docs --- docs/source/en/api/pipelines/helios.md | 68 +++++++++----------------- 1 file changed, 24 insertions(+), 44 deletions(-) diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index c6c30d8a0d36..0af889a947f7 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -70,11 +70,12 @@ pipeline.enable_group_offload( ) prompt = """ -A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various -elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. -The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but -emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and -exploration. Medium shot focusing on the train window and the rushing scenery beyond. 
+A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue +and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with +a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, +allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades +of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and +the vivid colors of its surroundings. A close-up shot with dynamic movement. """ negative_prompt = """ Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, @@ -122,11 +123,12 @@ pipeline.vae.compile(mode="max-autotune-no-cudagraphs", dynamic=False) pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=False) prompt = """ -A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various -elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. -The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but -emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and -exploration. Medium shot focusing on the train window and the rushing scenery beyond. +A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue +and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. 
The coral reefs are alive with +a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, +allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades +of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and +the vivid colors of its surroundings. A close-up shot with dynamic movement. """ negative_prompt = """ Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, @@ -177,11 +179,12 @@ misshapen limbs, fused fingers, still picture, messy background, three legs, man # For Text-to-Video prompt = """ -A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various -elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. -The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but -emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and -exploration. Medium shot focusing on the train window and the rushing scenery beyond. +A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue +and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with +a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, +allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades +of red, orange, and green. 
The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and +the vivid colors of its surroundings. A close-up shot with dynamic movement. """ output = pipeline( @@ -268,22 +271,19 @@ misshapen limbs, fused fingers, still picture, messy background, three legs, man # For Text-to-Video prompt = """ -A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various -elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. -The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but -emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and -exploration. Medium shot focusing on the train window and the rushing scenery beyond. +A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue +and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with +a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, +allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades +of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and +the vivid colors of its surroundings. A close-up shot with dynamic movement. 
""" output = pipeline( prompt=prompt, negative_prompt=negative_prompt, num_frames=99, - use_dynamic_shifting=True, - is_enable_stage2=True, - pyramid_num_stages=3, pyramid_num_inference_steps_list=[20, 20, 20], - use_cfg_zero_star=True, use_zero_init=True, zero_steps=1, ).frames[0] @@ -305,11 +305,7 @@ output = pipeline( negative_prompt=negative_prompt, image=load_image(image_path).resize((640, 384)), num_frames=99, - use_dynamic_shifting=True, - is_enable_stage2=True, - pyramid_num_stages=3, pyramid_num_inference_steps_list=[20, 20, 20], - use_cfg_zero_star=True, use_zero_init=True, zero_steps=1, ).frames[0] @@ -330,11 +326,7 @@ output = pipeline( negative_prompt=negative_prompt, video=load_video(video_path), num_frames=99, - use_dynamic_shifting=True, - is_enable_stage2=True, - pyramid_num_stages=3, pyramid_num_inference_steps_list=[20, 20, 20], - use_cfg_zero_star=True, use_zero_init=True, zero_steps=1, ).frames[0] @@ -385,11 +377,7 @@ output = pipeline( prompt=prompt, negative_prompt=negative_prompt, num_frames=99, - use_dynamic_shifting=True, - is_enable_stage2=True, - pyramid_num_stages=3, pyramid_num_inference_steps_list=[2, 2, 2], - use_dmd=True, guidance_scale=1.0, is_amplify_first_chunk=True, ).frames[0] @@ -411,11 +399,7 @@ output = pipeline( negative_prompt=negative_prompt, image=load_image(image_path).resize((640, 384)), num_frames=99, - use_dynamic_shifting=True, - is_enable_stage2=True, - pyramid_num_stages=3, pyramid_num_inference_steps_list=[2, 2, 2], - use_dmd=True, guidance_scale=1.0, is_amplify_first_chunk=True, ).frames[0] @@ -436,11 +420,7 @@ output = pipeline( negative_prompt=negative_prompt, video=load_video(video_path), num_frames=99, - use_dynamic_shifting=True, - is_enable_stage2=True, - pyramid_num_stages=3, pyramid_num_inference_steps_list=[2, 2, 2], - use_dmd=True, guidance_scale=1.0, is_amplify_first_chunk=True, ).frames[0] From 7236203c4838b406153843f541b350fb899b1a2a Mon Sep 17 00:00:00 2001 From: Shenghai Yuan 
<140951558+SHYuanBest@users.noreply.github.com> Date: Mon, 2 Mar 2026 10:19:06 +0800 Subject: [PATCH 079/107] Update tests/models/transformers/test_models_transformer_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- tests/models/transformers/test_models_transformer_helios.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/transformers/test_models_transformer_helios.py b/tests/models/transformers/test_models_transformer_helios.py index 73fe7a99b86d..631b13d41018 100644 --- a/tests/models/transformers/test_models_transformer_helios.py +++ b/tests/models/transformers/test_models_transformer_helios.py @@ -90,7 +90,7 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) - indices_hidden_states = torch.ones((2,)).to(torch_device) + indices_hidden_states = torch.ones((batch_size, num_frames)).to(torch_device) indices_latents_history_short = torch.ones((num_frames - 1,)).to(torch_device) indices_latents_history_mid = torch.ones((num_frames - 1,)).to(torch_device) indices_latents_history_long = torch.ones(((num_frames - 1) * 4,)).to(torch_device) From f460a89ae26d94ffdc6d37ee2256e7372b271726 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Mon, 2 Mar 2026 10:35:20 +0800 Subject: [PATCH 080/107] Update tests/models/transformers/test_models_transformer_helios.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- tests/models/transformers/test_models_transformer_helios.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_helios.py b/tests/models/transformers/test_models_transformer_helios.py index 
631b13d41018..eb1e85e74562 100644 --- a/tests/models/transformers/test_models_transformer_helios.py +++ b/tests/models/transformers/test_models_transformer_helios.py @@ -91,9 +91,9 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) indices_hidden_states = torch.ones((batch_size, num_frames)).to(torch_device) - indices_latents_history_short = torch.ones((num_frames - 1,)).to(torch_device) - indices_latents_history_mid = torch.ones((num_frames - 1,)).to(torch_device) - indices_latents_history_long = torch.ones(((num_frames - 1) * 4,)).to(torch_device) + indices_latents_history_short = torch.ones((batch_size, num_frames - 1)).to(torch_device) + indices_latents_history_mid = torch.ones((batch_size, num_frames - 1)).to(torch_device) + indices_latents_history_long = torch.ones((batch_size, (num_frames - 1) * 4)).to(torch_device) latents_history_short = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device) latents_history_mid = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device) latents_history_long = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to( From 6cccfa73fcbc785af987e18a2e61b2e987499ae0 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Mon, 2 Mar 2026 03:12:37 +0000 Subject: [PATCH 081/107] separate HeliosDMDScheduler --- 0_temp_helios_test/infer_helios.py | 5 +- docs/source/en/_toctree.yml | 2 + docs/source/en/api/pipelines/helios.md | 2 +- docs/source/en/api/schedulers/helios_dmd.md | 20 + src/diffusers/__init__.py | 2 + .../helios/pipeline_helios_pyramid.py | 26 +- src/diffusers/schedulers/__init__.py | 2 + src/diffusers/schedulers/scheduling_helios.py | 84 +--- .../schedulers/scheduling_helios_dmd.py | 389 ++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 15 + 10 
files changed, 455 insertions(+), 92 deletions(-) create mode 100644 docs/source/en/api/schedulers/helios_dmd.md create mode 100644 src/diffusers/schedulers/scheduling_helios_dmd.py diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index bb772fd72fed..5d8f24a2009a 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -196,7 +196,10 @@ def main(): torch_dtype=args.weight_dtype, use_default_loader=args.use_default_loader, ) - transformer.set_attention_backend("_flash_3_hub") + try: + transformer.set_attention_backend("_flash_3_hub") + except: + transformer.set_attention_backend("flash_hub") vae = AutoencoderKLWan.from_pretrained( args.base_model_path, diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d75218af0da8..ea06f35a0343 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -750,6 +750,8 @@ title: FlowMatchEulerDiscreteScheduler - local: api/schedulers/flow_match_heun_discrete title: FlowMatchHeunDiscreteScheduler + - local: api/schedulers/helios_dmd + title: HeliosDMDScheduler - local: api/schedulers/helios title: HeliosScheduler - local: api/schedulers/heun diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index 0af889a947f7..24ffae6e3939 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -30,7 +30,7 @@ The following Helios models are supported in Diffusers: - [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): Best Quality, with v-prediction, standard CFG and custom HeliosScheduler. - [Helios-Mid](https://huggingface.co/BestWishYsh/Helios-Mid): Intermediate Weight, with v-prediction, CFG-Zero* and custom HeliosScheduler. -- [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): Best Efficiency, with x0-prediction and custom HeliosScheduler. 
+- [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): Best Efficiency, with x0-prediction and custom HeliosDMDScheduler. > [!TIP] > Click on the Helios models in the right sidebar for more examples of video generation. diff --git a/docs/source/en/api/schedulers/helios_dmd.md b/docs/source/en/api/schedulers/helios_dmd.md new file mode 100644 index 000000000000..4f075e8a7dfc --- /dev/null +++ b/docs/source/en/api/schedulers/helios_dmd.md @@ -0,0 +1,20 @@ + + +# HeliosDMDScheduler + +`HeliosDMDScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers). + +## HeliosDMDScheduler +[[autodoc]] HeliosDMDScheduler + +scheduling_helios_dmd diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ea6e06471379..1458164191df 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -360,6 +360,7 @@ "FlowMatchEulerDiscreteScheduler", "FlowMatchHeunDiscreteScheduler", "FlowMatchLCMScheduler", + "HeliosDMDScheduler", "HeliosScheduler", "HeunDiscreteScheduler", "IPNDMScheduler", @@ -1127,6 +1128,7 @@ FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler, FlowMatchLCMScheduler, + HeliosDMDScheduler, HeliosScheduler, HeunDiscreteScheduler, IPNDMScheduler, diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 048cd220d1c7..5320a6393bf3 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -26,7 +26,7 @@ from ...image_processor import PipelineImageInput from ...loaders import HeliosLoraLoaderMixin from ...models import AutoencoderKLWan, HeliosTransformer3DModel -from ...schedulers import HeliosScheduler +from ...schedulers import HeliosDMDScheduler, HeliosScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import 
randn_tensor from ...video_processor import VideoProcessor @@ -133,7 +133,7 @@ class HeliosPyramidPipeline(DiffusionPipeline, HeliosLoraLoaderMixin): the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. transformer ([`HeliosTransformer3DModel`]): Conditional Transformer to denoise the input latents. - scheduler ([`HeliosScheduler`]): + scheduler ([`HeliosScheduler`, `HeliosDMDScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. @@ -148,7 +148,7 @@ def __init__( tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, vae: AutoencoderKLWan, - scheduler: HeliosScheduler, + scheduler: HeliosScheduler | HeliosDMDScheduler, transformer: HeliosTransformer3DModel, is_cfg_zero_star: bool = False, is_distilled: bool = False, @@ -1050,17 +1050,27 @@ def __call__( else: noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + extra_kwargs = ( + { + "cur_sampling_step": i, + "dmd_noisy_tensor": start_point_list[stage_idx] + if start_point_list is not None + else None, + "dmd_sigmas": self.scheduler.sigmas, + "dmd_timesteps": self.scheduler.timesteps, + "all_timesteps": timesteps, + } + if self.config.is_distilled + else {} + ) + latents = self.scheduler.step( noise_pred, t, latents, generator=generator, return_dict=False, - cur_sampling_step=i, - dmd_noisy_tensor=start_point_list[stage_idx] if start_point_list is not None else None, - dmd_sigmas=self.scheduler.sigmas, - dmd_timesteps=self.scheduler.timesteps, - all_timesteps=timesteps, + **extra_kwargs, )[0] if callback_on_step_end is not None: diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 56c58dbeb069..c7101d1b0401 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -62,6 +62,7 @@ 
_import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] _import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"] _import_structure["scheduling_helios"] = ["HeliosScheduler"] + _import_structure["scheduling_helios_dmd"] = ["HeliosDMDScheduler"] _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"] _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"] _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"] @@ -166,6 +167,7 @@ from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler from .scheduling_flow_match_lcm import FlowMatchLCMScheduler from .scheduling_helios import HeliosScheduler + from .scheduling_helios_dmd import HeliosDMDScheduler from .scheduling_heun_discrete import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler diff --git a/src/diffusers/schedulers/scheduling_helios.py b/src/diffusers/schedulers/scheduling_helios.py index 1d3a16ed12f2..ed35245c9db3 100644 --- a/src/diffusers/schedulers/scheduling_helios.py +++ b/src/diffusers/schedulers/scheduling_helios.py @@ -54,9 +54,9 @@ def __init__( disable_corrector: list[int] = [], solver_p: SchedulerMixin = None, use_flow_sigmas: bool = True, - scheduler_type: str = "unipc", # ["euler", "unipc", "dmd"] + scheduler_type: str = "unipc", # ["euler", "unipc"] use_dynamic_shifting: bool = False, - time_shift_type: Literal["exponential", "linear"] = "linear", + time_shift_type: Literal["exponential", "linear"] = "exponential", ): self.timestep_ratios = {} # The timestep ratio for each stage self.timesteps_per_stage = {} # The detailed timesteps per stage (fix max and min per stage) @@ -826,67 +826,6 @@ def step_unipc( this_order=self.this_order, ) - # ---------------------------------- For DMD ---------------------------------- - def add_noise(self, original_samples, 
noise, timestep, sigmas, timesteps): - sigmas = sigmas.to(noise.device) - timesteps = timesteps.to(noise.device) - timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1) - sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1) - sample = (1 - sigma) * original_samples + sigma * noise - return sample.type_as(noise) - - def convert_flow_pred_to_x0(self, flow_pred, xt, timestep, sigmas, timesteps): - # use higher precision for calculations - original_dtype = flow_pred.dtype - device = flow_pred.device - flow_pred, xt, sigmas, timesteps = (x.double().to(device) for x in (flow_pred, xt, sigmas, timesteps)) - - timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1) - sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1) - x0_pred = xt - sigma_t * flow_pred - return x0_pred.to(original_dtype) - - def step_dmd( - self, - model_output: torch.FloatTensor, - timestep: float | torch.FloatTensor = None, - sample: torch.FloatTensor = None, - generator: torch.Generator | None = None, - return_dict: bool = True, - cur_sampling_step: int = 0, - dmd_noisy_tensor: torch.FloatTensor | None = None, - dmd_sigmas: torch.FloatTensor | None = None, - dmd_timesteps: torch.FloatTensor | None = None, - all_timesteps: torch.FloatTensor | None = None, - ): - pred_image_or_video = self.convert_flow_pred_to_x0( - flow_pred=model_output, - xt=sample, - timestep=torch.full((model_output.shape[0],), timestep, dtype=torch.long, device=model_output.device), - sigmas=dmd_sigmas, - timesteps=dmd_timesteps, - ) - if cur_sampling_step < len(all_timesteps) - 1: - prev_sample = self.add_noise( - pred_image_or_video, - dmd_noisy_tensor, - torch.full( - (model_output.shape[0],), - all_timesteps[cur_sampling_step + 1], - dtype=torch.long, - device=model_output.device, - ), - sigmas=dmd_sigmas, - timesteps=dmd_timesteps, - ) - else: - prev_sample = pred_image_or_video - - if not return_dict: - return (prev_sample,) - - return 
HeliosSchedulerOutput(prev_sample=prev_sample) - # ---------------------------------- Merge ---------------------------------- def step( self, @@ -895,12 +834,6 @@ def step( sample: torch.FloatTensor = None, generator: torch.Generator | None = None, return_dict: bool = True, - # For DMD - cur_sampling_step: int = 0, - dmd_noisy_tensor: torch.FloatTensor | None = None, - dmd_sigmas: torch.FloatTensor | None = None, - dmd_timesteps: torch.FloatTensor | None = None, - all_timesteps: torch.FloatTensor | None = None, ) -> HeliosSchedulerOutput | tuple: if self.config.scheduler_type == "euler": return self.step_euler( @@ -917,19 +850,6 @@ def step( sample=sample, return_dict=return_dict, ) - elif self.config.scheduler_type == "dmd": - return self.step_dmd( - model_output=model_output, - timestep=timestep, - sample=sample, - generator=generator, - return_dict=return_dict, - cur_sampling_step=cur_sampling_step, - dmd_noisy_tensor=dmd_noisy_tensor, - dmd_sigmas=dmd_sigmas, - dmd_timesteps=dmd_timesteps, - all_timesteps=all_timesteps, - ) else: raise NotImplementedError diff --git a/src/diffusers/schedulers/scheduling_helios_dmd.py b/src/diffusers/schedulers/scheduling_helios_dmd.py new file mode 100644 index 000000000000..e13545dc3c6e --- /dev/null +++ b/src/diffusers/schedulers/scheduling_helios_dmd.py @@ -0,0 +1,389 @@ +# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from dataclasses import dataclass +from typing import Literal + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..schedulers.scheduling_utils import SchedulerMixin +from ..utils import BaseOutput + + +@dataclass +class HeliosDMDSchedulerOutput(BaseOutput): + prev_sample: torch.FloatTensor + model_outputs: torch.FloatTensor | None = None + last_sample: torch.FloatTensor | None = None + this_order: int | None = None + + +class HeliosDMDScheduler(SchedulerMixin, ConfigMixin): + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, # Following Stable diffusion 3, + stages: int = 3, + stage_range: list = [0, 1 / 3, 2 / 3, 1], + gamma: float = 1 / 3, + thresholding: bool = False, + prediction_type: str = "flow_prediction", + solver_order: int = 2, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: list[int] = [], + solver_p: SchedulerMixin = None, + use_flow_sigmas: bool = True, + scheduler_type: str = "dmd", + use_dynamic_shifting: bool = False, + time_shift_type: Literal["exponential", "linear"] = "linear", + ): + self.timestep_ratios = {} # The timestep ratio for each stage + self.timesteps_per_stage = {} # The detailed timesteps per stage (fix max and min per stage) + self.sigmas_per_stage = {} # always uniform [1000, 0] + self.start_sigmas = {} # for start point / upsample renoise + self.end_sigmas = {} # for end point + self.ori_start_sigmas = {} + + # self.init_sigmas() + self.init_sigmas_for_each_stage() + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + self.gamma = gamma + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + 
self.predict_x0 = predict_x0 + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + def init_sigmas(self): + """ + initialize the global timesteps and sigmas + """ + num_train_timesteps = self.config.num_train_timesteps + shift = self.config.shift + + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps + 1) + sigmas = 1.0 - alphas + sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy() + sigmas = torch.from_numpy(sigmas) + timesteps = (sigmas * num_train_timesteps).clone() + + self._step_index = None + self._begin_index = None + self.timesteps = timesteps + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def init_sigmas_for_each_stage(self): + """ + Init the timesteps for each stage + """ + self.init_sigmas() + + stage_distance = [] + stages = self.config.stages + training_steps = self.config.num_train_timesteps + stage_range = self.config.stage_range + + # Init the start and end point of each stage + for i_s in range(stages): + # To decide the start and ends point + start_indice = int(stage_range[i_s] * training_steps) + start_indice = max(start_indice, 0) + end_indice = int(stage_range[i_s + 1] * training_steps) + end_indice = min(end_indice, training_steps) + start_sigma = self.sigmas[start_indice].item() + end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0 + self.ori_start_sigmas[i_s] = start_sigma + + if i_s != 0: + ori_sigma = 1 - start_sigma + gamma = self.config.gamma + corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma + # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma + start_sigma = 1 - corrected_sigma + + stage_distance.append(start_sigma - end_sigma) + self.start_sigmas[i_s] = start_sigma + 
self.end_sigmas[i_s] = end_sigma + + # Determine the ratio of each stage according to flow length + tot_distance = sum(stage_distance) + for i_s in range(stages): + if i_s == 0: + start_ratio = 0.0 + else: + start_ratio = sum(stage_distance[:i_s]) / tot_distance + if i_s == stages - 1: + end_ratio = 0.9999999999999999 + else: + end_ratio = sum(stage_distance[: i_s + 1]) / tot_distance + + self.timestep_ratios[i_s] = (start_ratio, end_ratio) + + # Determine the timesteps and sigmas for each stage + for i_s in range(stages): + timestep_ratio = self.timestep_ratios[i_s] + # timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)] + timestep_max = min(self.timesteps[int(timestep_ratio[0] * training_steps)], 999) + timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)] + timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1) + self.timesteps_per_stage[i_s] = ( + timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1]) + ) + stage_sigmas = np.linspace(0.999, 0, training_steps + 1) + self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1]) + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. 
+ """ + self._begin_index = begin_index + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps( + self, + num_inference_steps: int, + stage_index: int | None = None, + device: str | torch.device = None, + sigmas: bool | None = None, + mu: bool | None = None, + is_amplify_first_chunk: bool = False, + ): + """ + Setting the timesteps and sigmas for each stage + """ + if self.config.scheduler_type == "dmd": + if is_amplify_first_chunk: + num_inference_steps = num_inference_steps * 2 + 1 + else: + num_inference_steps = num_inference_steps + 1 + + self.num_inference_steps = num_inference_steps + self.init_sigmas() + + if self.config.stages == 1: + if sigmas is None: + sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype( + np.float32 + ) + if self.config.shift != 1.0: + assert not self.config.use_dynamic_shifting + sigmas = self.time_shift(self.config.shift, 1.0, sigmas) + timesteps = (sigmas * self.config.num_train_timesteps).copy() + sigmas = torch.from_numpy(sigmas) + else: + stage_timesteps = self.timesteps_per_stage[stage_index] + timesteps = np.linspace( + stage_timesteps[0].item(), + stage_timesteps[-1].item(), + num_inference_steps, + ) + + stage_sigmas = self.sigmas_per_stage[stage_index] + ratios = np.linspace(stage_sigmas[0].item(), stage_sigmas[-1].item(), num_inference_steps) + sigmas = torch.from_numpy(ratios) + + self.timesteps = torch.from_numpy(timesteps).to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device=device) + + self._step_index = None + self.reset_scheduler_history() + + if self.config.scheduler_type == "dmd": + self.timesteps = self.timesteps[:-1] + self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]]) + + if self.config.use_dynamic_shifting: + assert self.config.shift == 1.0 + self.sigmas = self.time_shift(mu, 1.0, self.sigmas) + if self.config.stages == 1: + self.timesteps = self.sigmas[:-1] * 
self.config.num_train_timesteps + else: + self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * ( + self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min() + ) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + """ + Apply time shifting to the sigmas. + + Args: + mu (`float`): + The mu parameter for the time shift. + sigma (`float`): + The sigma parameter for the time shift. + t (`torch.Tensor`): + The input timesteps. + + Returns: + `torch.Tensor`: + The time-shifted timesteps. + """ + if self.config.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.config.time_shift_type == "linear": + return self._time_shift_linear(mu, sigma, t) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential + def _time_shift_exponential(self, mu, sigma, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear + def _time_shift_linear(self, mu, sigma, t): + return mu / (mu + (1 / t - 1) ** sigma) + + # ---------------------------------- For DMD ---------------------------------- + def add_noise(self, original_samples, noise, timestep, sigmas, timesteps): + sigmas = sigmas.to(noise.device) + timesteps = timesteps.to(noise.device) + timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1) + sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1) + sample = (1 - sigma) * original_samples + sigma * noise + return sample.type_as(noise) + + def convert_flow_pred_to_x0(self, flow_pred, xt, timestep, sigmas, timesteps): + # use higher precision for calculations + original_dtype = flow_pred.dtype + device = 
flow_pred.device + flow_pred, xt, sigmas, timesteps = (x.double().to(device) for x in (flow_pred, xt, sigmas, timesteps)) + + timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1) + sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1) + x0_pred = xt - sigma_t * flow_pred + return x0_pred.to(original_dtype) + + def step_dmd( + self, + model_output: torch.FloatTensor, + timestep: float | torch.FloatTensor = None, + sample: torch.FloatTensor = None, + generator: torch.Generator | None = None, + return_dict: bool = True, + cur_sampling_step: int = 0, + dmd_noisy_tensor: torch.FloatTensor | None = None, + dmd_sigmas: torch.FloatTensor | None = None, + dmd_timesteps: torch.FloatTensor | None = None, + all_timesteps: torch.FloatTensor | None = None, + ): + pred_image_or_video = self.convert_flow_pred_to_x0( + flow_pred=model_output, + xt=sample, + timestep=torch.full((model_output.shape[0],), timestep, dtype=torch.long, device=model_output.device), + sigmas=dmd_sigmas, + timesteps=dmd_timesteps, + ) + if cur_sampling_step < len(all_timesteps) - 1: + prev_sample = self.add_noise( + pred_image_or_video, + dmd_noisy_tensor, + torch.full( + (model_output.shape[0],), + all_timesteps[cur_sampling_step + 1], + dtype=torch.long, + device=model_output.device, + ), + sigmas=dmd_sigmas, + timesteps=dmd_timesteps, + ) + else: + prev_sample = pred_image_or_video + + if not return_dict: + return (prev_sample,) + + return HeliosDMDSchedulerOutput(prev_sample=prev_sample) + + # ---------------------------------- Merge ---------------------------------- + def step( + self, + model_output: torch.FloatTensor, + timestep: float | torch.FloatTensor = None, + sample: torch.FloatTensor = None, + generator: torch.Generator | None = None, + return_dict: bool = True, + cur_sampling_step: int = 0, + dmd_noisy_tensor: torch.FloatTensor | None = None, + dmd_sigmas: torch.FloatTensor | None = None, + dmd_timesteps: torch.FloatTensor | None = None, + all_timesteps: 
torch.FloatTensor | None = None, + ) -> HeliosDMDSchedulerOutput | tuple: + if self.config.scheduler_type == "dmd": + return self.step_dmd( + model_output=model_output, + timestep=timestep, + sample=sample, + generator=generator, + return_dict=return_dict, + cur_sampling_step=cur_sampling_step, + dmd_noisy_tensor=dmd_noisy_tensor, + dmd_sigmas=dmd_sigmas, + dmd_timesteps=dmd_timesteps, + all_timesteps=all_timesteps, + ) + else: + raise NotImplementedError + + def reset_scheduler_history(self): + self.model_outputs = [None] * self.config.solver_order + self.timestep_list = [None] * self.config.solver_order + self.lower_order_nums = 0 + self.disable_corrector = self.config.disable_corrector + self.solver_p = self.config.solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6c12a4f04d2f..3a4aecd24f90 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2758,6 +2758,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HeliosDMDScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HeliosScheduler(metaclass=DummyObject): _backends = ["torch"] From 3bec7cfeb0c7e604448fa774aaf906c67a14f4c3 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Mon, 2 Mar 2026 03:47:09 +0000 Subject: [PATCH 082/107] fix numerical stability issue: --- src/diffusers/pipelines/helios/pipeline_helios_pyramid.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py 
b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 5320a6393bf3..4c47b89b5e7f 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -473,6 +473,7 @@ def sample_block_noise(self, batch_size, channel, num_frames, height, width, pat block_size = ph * pw cov = torch.eye(block_size) * (1 + gamma) - torch.ones(block_size, block_size) * gamma + cov += torch.eye(block_size) * 1e-6 dist = torch.distributions.MultivariateNormal( torch.zeros(block_size, device=cov.device), covariance_matrix=cov ) From 2800ac5f65cdf12a4f6ea9dacddbfe81c1ebbee3 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Mon, 2 Mar 2026 12:15:07 +0800 Subject: [PATCH 083/107] Update src/diffusers/schedulers/scheduling_helios_dmd.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/schedulers/scheduling_helios_dmd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_helios_dmd.py b/src/diffusers/schedulers/scheduling_helios_dmd.py index e13545dc3c6e..e81bde6d8b67 100644 --- a/src/diffusers/schedulers/scheduling_helios_dmd.py +++ b/src/diffusers/schedulers/scheduling_helios_dmd.py @@ -53,7 +53,6 @@ def __init__( disable_corrector: list[int] = [], solver_p: SchedulerMixin = None, use_flow_sigmas: bool = True, - scheduler_type: str = "dmd", use_dynamic_shifting: bool = False, time_shift_type: Literal["exponential", "linear"] = "linear", ): From e2c5ed372b85b7743bd2301ee3072a6467ce6c30 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Mon, 2 Mar 2026 12:15:20 +0800 Subject: [PATCH 084/107] Update src/diffusers/schedulers/scheduling_helios_dmd.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/schedulers/scheduling_helios_dmd.py | 4 ---- 1 file changed, 4 deletions(-) diff --git 
a/src/diffusers/schedulers/scheduling_helios_dmd.py b/src/diffusers/schedulers/scheduling_helios_dmd.py index e81bde6d8b67..f310d240f0cd 100644 --- a/src/diffusers/schedulers/scheduling_helios_dmd.py +++ b/src/diffusers/schedulers/scheduling_helios_dmd.py @@ -48,10 +48,6 @@ def __init__( prediction_type: str = "flow_prediction", solver_order: int = 2, predict_x0: bool = True, - solver_type: str = "bh2", - lower_order_final: bool = True, - disable_corrector: list[int] = [], - solver_p: SchedulerMixin = None, use_flow_sigmas: bool = True, use_dynamic_shifting: bool = False, time_shift_type: Literal["exponential", "linear"] = "linear", From 0f2c8281bfd76198ae55b161cfbd4a0fe2ed3051 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Mon, 2 Mar 2026 12:17:34 +0800 Subject: [PATCH 085/107] Update src/diffusers/schedulers/scheduling_helios_dmd.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/schedulers/scheduling_helios_dmd.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_helios_dmd.py b/src/diffusers/schedulers/scheduling_helios_dmd.py index f310d240f0cd..e65cdedbdb77 100644 --- a/src/diffusers/schedulers/scheduling_helios_dmd.py +++ b/src/diffusers/schedulers/scheduling_helios_dmd.py @@ -198,11 +198,10 @@ def set_timesteps( """ Setting the timesteps and sigmas for each stage """ - if self.config.scheduler_type == "dmd": - if is_amplify_first_chunk: - num_inference_steps = num_inference_steps * 2 + 1 - else: - num_inference_steps = num_inference_steps + 1 + if is_amplify_first_chunk: + num_inference_steps = num_inference_steps * 2 + 1 + else: + num_inference_steps = num_inference_steps + 1 self.num_inference_steps = num_inference_steps self.init_sigmas() From 9601ee1112a410fa9b00a0480c2d31c997d1c398 Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Mon, 2 Mar 2026 12:18:07 
+0800 Subject: [PATCH 086/107] Update src/diffusers/schedulers/scheduling_helios_dmd.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/schedulers/scheduling_helios_dmd.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_helios_dmd.py b/src/diffusers/schedulers/scheduling_helios_dmd.py index e65cdedbdb77..025b22ab4447 100644 --- a/src/diffusers/schedulers/scheduling_helios_dmd.py +++ b/src/diffusers/schedulers/scheduling_helios_dmd.py @@ -234,9 +234,8 @@ def set_timesteps( self._step_index = None self.reset_scheduler_history() - if self.config.scheduler_type == "dmd": - self.timesteps = self.timesteps[:-1] - self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]]) + self.timesteps = self.timesteps[:-1] + self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]]) if self.config.use_dynamic_shifting: assert self.config.shift == 1.0 From ae2a300b4cbb4e531b96ad13618af619a8240d3a Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: Mon, 2 Mar 2026 12:21:09 +0800 Subject: [PATCH 087/107] Update src/diffusers/schedulers/scheduling_helios_dmd.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/schedulers/scheduling_helios_dmd.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_helios_dmd.py b/src/diffusers/schedulers/scheduling_helios_dmd.py index 025b22ab4447..47c07884c34f 100644 --- a/src/diffusers/schedulers/scheduling_helios_dmd.py +++ b/src/diffusers/schedulers/scheduling_helios_dmd.py @@ -369,12 +369,6 @@ def step( raise NotImplementedError def reset_scheduler_history(self): - self.model_outputs = [None] * self.config.solver_order - self.timestep_list = [None] * self.config.solver_order - self.lower_order_nums = 0 - self.disable_corrector = self.config.disable_corrector - self.solver_p = self.config.solver_p - self.last_sample = None self._step_index 
= None self._begin_index = None From d9515b947bf4dabf68b39f03e8d82e65a6abfd90 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Mon, 2 Mar 2026 04:27:21 +0000 Subject: [PATCH 088/107] remove redudant --- .../schedulers/scheduling_helios_dmd.py | 42 ++++++------------- 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_helios_dmd.py b/src/diffusers/schedulers/scheduling_helios_dmd.py index 47c07884c34f..86260849d48e 100644 --- a/src/diffusers/schedulers/scheduling_helios_dmd.py +++ b/src/diffusers/schedulers/scheduling_helios_dmd.py @@ -44,10 +44,7 @@ def __init__( stages: int = 3, stage_range: list = [0, 1 / 3, 2 / 3, 1], gamma: float = 1 / 3, - thresholding: bool = False, prediction_type: str = "flow_prediction", - solver_order: int = 2, - predict_x0: bool = True, use_flow_sigmas: bool = True, use_dynamic_shifting: bool = False, time_shift_type: Literal["exponential", "linear"] = "linear", @@ -65,18 +62,6 @@ def __init__( self.sigma_max = self.sigmas[0].item() self.gamma = gamma - if solver_type not in ["bh1", "bh2"]: - if solver_type in ["midpoint", "heun", "logrho"]: - self.register_to_config(solver_type="bh2") - else: - raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") - - self.predict_x0 = predict_x0 - self.model_outputs = [None] * solver_order - self.timestep_list = [None] * solver_order - self.lower_order_nums = 0 - self.disable_corrector = disable_corrector - self.solver_p = solver_p self.last_sample = None self._step_index = None self._begin_index = None @@ -352,21 +337,18 @@ def step( dmd_timesteps: torch.FloatTensor | None = None, all_timesteps: torch.FloatTensor | None = None, ) -> HeliosDMDSchedulerOutput | tuple: - if self.config.scheduler_type == "dmd": - return self.step_dmd( - model_output=model_output, - timestep=timestep, - sample=sample, - generator=generator, - return_dict=return_dict, - cur_sampling_step=cur_sampling_step, - dmd_noisy_tensor=dmd_noisy_tensor, - 
dmd_sigmas=dmd_sigmas, - dmd_timesteps=dmd_timesteps, - all_timesteps=all_timesteps, - ) - else: - raise NotImplementedError + return self.step_dmd( + model_output=model_output, + timestep=timestep, + sample=sample, + generator=generator, + return_dict=return_dict, + cur_sampling_step=cur_sampling_step, + dmd_noisy_tensor=dmd_noisy_tensor, + dmd_sigmas=dmd_sigmas, + dmd_timesteps=dmd_timesteps, + all_timesteps=all_timesteps, + ) def reset_scheduler_history(self): self._step_index = None From 6c167c0c10221b0c01cebaa8edca6ffaa19fc39c Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Mon, 2 Mar 2026 04:29:35 +0000 Subject: [PATCH 089/107] small refactor --- .../schedulers/scheduling_helios_dmd.py | 31 ++----------------- 1 file changed, 2 insertions(+), 29 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_helios_dmd.py b/src/diffusers/schedulers/scheduling_helios_dmd.py index 86260849d48e..1f4afa0e3128 100644 --- a/src/diffusers/schedulers/scheduling_helios_dmd.py +++ b/src/diffusers/schedulers/scheduling_helios_dmd.py @@ -282,7 +282,7 @@ def convert_flow_pred_to_x0(self, flow_pred, xt, timestep, sigmas, timesteps): x0_pred = xt - sigma_t * flow_pred return x0_pred.to(original_dtype) - def step_dmd( + def step( self, model_output: torch.FloatTensor, timestep: float | torch.FloatTensor = None, @@ -294,7 +294,7 @@ def step_dmd( dmd_sigmas: torch.FloatTensor | None = None, dmd_timesteps: torch.FloatTensor | None = None, all_timesteps: torch.FloatTensor | None = None, - ): + ) -> HeliosDMDSchedulerOutput | tuple: pred_image_or_video = self.convert_flow_pred_to_x0( flow_pred=model_output, xt=sample, @@ -323,33 +323,6 @@ def step_dmd( return HeliosDMDSchedulerOutput(prev_sample=prev_sample) - # ---------------------------------- Merge ---------------------------------- - def step( - self, - model_output: torch.FloatTensor, - timestep: float | torch.FloatTensor = None, - sample: torch.FloatTensor = None, - generator: torch.Generator | None = None, - return_dict: 
bool = True, - cur_sampling_step: int = 0, - dmd_noisy_tensor: torch.FloatTensor | None = None, - dmd_sigmas: torch.FloatTensor | None = None, - dmd_timesteps: torch.FloatTensor | None = None, - all_timesteps: torch.FloatTensor | None = None, - ) -> HeliosDMDSchedulerOutput | tuple: - return self.step_dmd( - model_output=model_output, - timestep=timestep, - sample=sample, - generator=generator, - return_dict=return_dict, - cur_sampling_step=cur_sampling_step, - dmd_noisy_tensor=dmd_noisy_tensor, - dmd_sigmas=dmd_sigmas, - dmd_timesteps=dmd_timesteps, - all_timesteps=all_timesteps, - ) - def reset_scheduler_history(self): self._step_index = None self._begin_index = None From 256544bed7d594313e04a032a0beff56c68b045d Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Mon, 2 Mar 2026 08:31:58 +0000 Subject: [PATCH 090/107] remove use_interpolate_prompt logits --- 0_temp_helios_test/infer_helios.py | 8 -- .../pipelines/helios/pipeline_helios.py | 85 +------------------ .../helios/pipeline_helios_pyramid.py | 85 +------------------ 3 files changed, 6 insertions(+), 172 deletions(-) diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py index 5d8f24a2009a..25432678ce14 100644 --- a/0_temp_helios_test/infer_helios.py +++ b/0_temp_helios_test/infer_helios.py @@ -268,10 +268,6 @@ def main(): image=load_image(image_path).resize((args.width, args.height)) if image_path is not None else None, # t2v video=load_video(video_path) if video_path is not None else None, - # interpolate_prompt - use_interpolate_prompt=args.use_interpolate_prompt, - interpolation_steps=args.interpolation_steps, - interpolate_time_list=interpolate_time_list, ).frames[0] else: output = pipe( @@ -298,10 +294,6 @@ def main(): image=load_image(image_path).resize((args.width, args.height)) if image_path is not None else None, # t2v video=load_video(video_path) if video_path is not None else None, - # interpolate_prompt - use_interpolate_prompt=args.use_interpolate_prompt, - 
interpolation_steps=args.interpolation_steps, - interpolate_time_list=interpolate_time_list, ).frames[0] if not args.enable_parallelism or rank == 0: diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index c11a11b50852..f0fbd0b90b21 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -13,7 +13,6 @@ # limitations under the License. import html -from itertools import accumulate from typing import Any, Callable import numpy as np @@ -286,10 +285,6 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, image=None, video=None, - use_interpolate_prompt=False, - num_videos_per_prompt=None, - interpolate_time_list=None, - interpolation_steps=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -325,16 +320,6 @@ def check_inputs( if image is not None and video is not None: raise ValueError("image and video cannot be provided simultaneously") - if use_interpolate_prompt: - assert num_videos_per_prompt == 1, f"num_videos_per_prompt must be 1, got {num_videos_per_prompt}" - assert isinstance(prompt, list), "prompt must be a list" - assert len(prompt) == len(interpolate_time_list), ( - f"Length mismatch: {len(prompt)} vs {len(interpolate_time_list)}" - ) - assert min(interpolate_time_list) > interpolation_steps, ( - f"Minimum value {min(interpolate_time_list)} must be greater than {interpolation_steps}" - ) - def prepare_latents( self, batch_size: int, @@ -433,20 +418,6 @@ def prepare_video_latents( latents = torch.cat(latents_chunks, dim=2) return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype) - def interpolate_prompt_embeds( - self, - prompt_embeds_1: torch.Tensor, - prompt_embeds_2: torch.Tensor, - interpolation_steps: int = 3, - ): - x = torch.lerp( - prompt_embeds_1, - prompt_embeds_2, - 
torch.linspace(0, 1, steps=interpolation_steps).unsqueeze(1).unsqueeze(2).to(prompt_embeds_1), - ) - interpolated_prompt_embeds = list(x.chunk(interpolation_steps, dim=0)) - return interpolated_prompt_embeds - @property def guidance_scale(self): return self._guidance_scale @@ -507,10 +478,6 @@ def __call__( add_noise_to_video_latents: bool = True, video_noise_sigma_min: float = 0.111, video_noise_sigma_max: float = 0.135, - # ------------ Interactive ------------ - use_interpolate_prompt: bool = False, - interpolate_time_list: list = [7, 7, 7], - interpolation_steps: int = 3, # ------------ Stage 1 ------------ history_sizes: list = [16, 2, 1], num_latent_frames_per_chunk: int = 9, @@ -608,10 +575,6 @@ def __call__( callback_on_step_end_tensor_inputs, image, video, - use_interpolate_prompt, - num_videos_per_prompt, - interpolate_time_list, - interpolation_steps, ) num_frames = max(num_frames, 1) @@ -625,7 +588,7 @@ def __call__( vae_dtype = self.vae.dtype # 2. Define call parameters - if use_interpolate_prompt or (prompt is not None and isinstance(prompt, str)): + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -633,12 +596,7 @@ def __call__( batch_size = prompt_embeds.shape[0] # 3. 
Encode input prompt - if use_interpolate_prompt: - interpolate_interval_idx = None - interpolate_embeds = None - interpolate_cumulative_list = list(accumulate(interpolate_time_list)) - - all_prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, @@ -650,10 +608,8 @@ def __call__( ) transformer_dtype = self.transformer.dtype - all_prompt_embeds = all_prompt_embeds.to(transformer_dtype) + prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: - if use_interpolate_prompt: - negative_prompt_embeds = negative_prompt_embeds[0].unsqueeze(0) negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) # 4. Prepare image or video @@ -790,11 +746,6 @@ def __call__( indices_latents_history_long = indices_latents_history_long.unsqueeze(0) # 6. Denoising loop - if use_interpolate_prompt: - if num_latent_chunk < max(interpolate_cumulative_list): - num_latent_chunk = sum(interpolate_cumulative_list) - print(f"Update num_latent_chunk to: {num_latent_chunk}") - patch_size = self.transformer.config.patch_size image_seq_len = ( num_latent_frames_per_chunk @@ -812,36 +763,6 @@ def __call__( ) for k in range(num_latent_chunk): - if use_interpolate_prompt: - assert num_latent_chunk >= max(interpolate_cumulative_list) - - current_interval_idx = 0 - for idx, cumulative_val in enumerate(interpolate_cumulative_list): - if k < cumulative_val: - current_interval_idx = idx - break - - if current_interval_idx == 0: - prompt_embeds = all_prompt_embeds[0].unsqueeze(0) - else: - interval_start = interpolate_cumulative_list[current_interval_idx - 1] - position_in_interval = k - interval_start - - if position_in_interval < interpolation_steps: - if interpolate_embeds is None or interpolate_interval_idx != current_interval_idx: - interpolate_embeds = self.interpolate_prompt_embeds( 
- prompt_embeds_1=all_prompt_embeds[current_interval_idx - 1].unsqueeze(0), - prompt_embeds_2=all_prompt_embeds[current_interval_idx].unsqueeze(0), - interpolation_steps=interpolation_steps, - ) - interpolate_interval_idx = current_interval_idx - - prompt_embeds = interpolate_embeds[position_in_interval] - else: - prompt_embeds = all_prompt_embeds[current_interval_idx].unsqueeze(0) - else: - prompt_embeds = all_prompt_embeds - is_first_chunk = k == 0 is_second_chunk = k == 1 if keep_first_frame: diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 4c47b89b5e7f..4773f1befec5 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -14,7 +14,6 @@ import html import math -from itertools import accumulate from typing import Any, Callable import regex as re @@ -302,10 +301,6 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, image=None, video=None, - use_interpolate_prompt=False, - num_videos_per_prompt=None, - interpolate_time_list=None, - interpolation_steps=None, guidance_scale=None, ): if height % 16 != 0 or width % 16 != 0: @@ -342,16 +337,6 @@ def check_inputs( if image is not None and video is not None: raise ValueError("image and video cannot be provided simultaneously") - if use_interpolate_prompt: - assert num_videos_per_prompt == 1, f"num_videos_per_prompt must be 1, got {num_videos_per_prompt}" - assert isinstance(prompt, list), "prompt must be a list" - assert len(prompt) == len(interpolate_time_list), ( - f"Length mismatch: {len(prompt)} vs {len(interpolate_time_list)}" - ) - assert min(interpolate_time_list) > interpolation_steps, ( - f"Minimum value {min(interpolate_time_list)} must be greater than {interpolation_steps}" - ) - if guidance_scale > 1.0 and self.config.is_distilled: logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") @@ 
-453,20 +438,6 @@ def prepare_video_latents( latents = torch.cat(latents_chunks, dim=2) return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype) - def interpolate_prompt_embeds( - self, - prompt_embeds_1: torch.Tensor, - prompt_embeds_2: torch.Tensor, - interpolation_steps: int = 3, - ): - x = torch.lerp( - prompt_embeds_1, - prompt_embeds_2, - torch.linspace(0, 1, steps=interpolation_steps).unsqueeze(1).unsqueeze(2).to(prompt_embeds_1), - ) - interpolated_prompt_embeds = list(x.chunk(interpolation_steps, dim=0)) - return interpolated_prompt_embeds - def sample_block_noise(self, batch_size, channel, num_frames, height, width, patch_size=(1, 2, 2)): gamma = self.scheduler.config.gamma _, ph, pw = patch_size @@ -543,10 +514,6 @@ def __call__( add_noise_to_video_latents: bool = True, video_noise_sigma_min: float = 0.111, video_noise_sigma_max: float = 0.135, - # ------------ Interactive ------------ - use_interpolate_prompt: bool = False, - interpolate_time_list: list = [7, 7, 7], - interpolation_steps: int = 3, # ------------ Stage 1 ------------ history_sizes: list = [16, 2, 1], num_latent_frames_per_chunk: int = 9, @@ -649,10 +616,6 @@ def __call__( callback_on_step_end_tensor_inputs, image, video, - use_interpolate_prompt, - num_videos_per_prompt, - interpolate_time_list, - interpolation_steps, guidance_scale, ) @@ -667,7 +630,7 @@ def __call__( vae_dtype = self.vae.dtype # 2. Define call parameters - if use_interpolate_prompt or (prompt is not None and isinstance(prompt, str)): + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) @@ -675,12 +638,7 @@ def __call__( batch_size = prompt_embeds.shape[0] # 3. 
Encode input prompt - if use_interpolate_prompt: - interpolate_interval_idx = None - interpolate_embeds = None - interpolate_cumulative_list = list(accumulate(interpolate_time_list)) - - all_prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, @@ -692,10 +650,8 @@ def __call__( ) transformer_dtype = self.transformer.dtype - all_prompt_embeds = all_prompt_embeds.to(transformer_dtype) + prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: - if use_interpolate_prompt: - negative_prompt_embeds = negative_prompt_embeds[0].unsqueeze(0) negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) # 4. Prepare image or video @@ -832,42 +788,7 @@ def __call__( indices_latents_history_long = indices_latents_history_long.unsqueeze(0) # 6. Denoising loop - if use_interpolate_prompt: - if num_latent_chunk < max(interpolate_cumulative_list): - num_latent_chunk = sum(interpolate_cumulative_list) - print(f"Update num_latent_chunk to: {num_latent_chunk}") - for k in range(num_latent_chunk): - if use_interpolate_prompt: - assert num_latent_chunk >= max(interpolate_cumulative_list) - - current_interval_idx = 0 - for idx, cumulative_val in enumerate(interpolate_cumulative_list): - if k < cumulative_val: - current_interval_idx = idx - break - - if current_interval_idx == 0: - prompt_embeds = all_prompt_embeds[0].unsqueeze(0) - else: - interval_start = interpolate_cumulative_list[current_interval_idx - 1] - position_in_interval = k - interval_start - - if position_in_interval < interpolation_steps: - if interpolate_embeds is None or interpolate_interval_idx != current_interval_idx: - interpolate_embeds = self.interpolate_prompt_embeds( - prompt_embeds_1=all_prompt_embeds[current_interval_idx - 1].unsqueeze(0), - 
prompt_embeds_2=all_prompt_embeds[current_interval_idx].unsqueeze(0), - interpolation_steps=interpolation_steps, - ) - interpolate_interval_idx = current_interval_idx - - prompt_embeds = interpolate_embeds[position_in_interval] - else: - prompt_embeds = all_prompt_embeds[current_interval_idx].unsqueeze(0) - else: - prompt_embeds = all_prompt_embeds - is_first_chunk = k == 0 is_second_chunk = k == 1 if keep_first_frame: From 724d6b763eeacd59e20b248fe313cdc8729c5dc2 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Mon, 2 Mar 2026 16:42:49 +0000 Subject: [PATCH 091/107] simplified model test --- .../test_models_transformer_helios.py | 138 ++++++------------ 1 file changed, 41 insertions(+), 97 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_helios.py b/tests/models/transformers/test_models_transformer_helios.py index eb1e85e74562..48ed52c75f8c 100644 --- a/tests/models/transformers/test_models_transformer_helios.py +++ b/tests/models/transformers/test_models_transformer_helios.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import unittest + import pytest import torch @@ -20,11 +22,8 @@ from ...testing_utils import enable_full_determinism, torch_device from ..testing_utils import ( AttentionTesterMixin, - BaseModelTesterConfig, - BitsAndBytesTesterMixin, MemoryTesterMixin, ModelTesterMixin, - TorchAoTesterMixin, TorchCompileTesterMixin, TrainingTesterMixin, ) @@ -33,52 +32,13 @@ enable_full_determinism() -class HeliosTransformer3DTesterConfig(BaseModelTesterConfig): - @property - def model_class(self): - return HeliosTransformer3DModel - - @property - def pretrained_model_name_or_path(self): - return "hf-internal-testing/tiny-helios-base-transformer" +class HeliosTransformer3DTesterConfig(ModelTesterMixin, unittest.TestCase): + model_class = HeliosTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True @property - def output_shape(self) -> tuple[int, ...]: - return (4, 2, 16, 16) - - @property - def input_shape(self) -> tuple[int, ...]: - return (4, 2, 16, 16) - - @property - def main_input_name(self) -> str: - return "hidden_states" - - @property - def generator(self): - return torch.Generator("cpu").manual_seed(0) - - def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]: - return { - "patch_size": (1, 2, 2), - "num_attention_heads": 2, - "attention_head_dim": 12, - "in_channels": 4, - "out_channels": 4, - "text_dim": 16, - "freq_dim": 256, - "ffn_dim": 32, - "num_layers": 2, - "cross_attn_norm": True, - "qk_norm": "rms_norm_across_heads", - "rope_dim": (4, 4, 4), - "has_multi_term_memory_patch": True, - "guidance_cross_attn": True, - "zero_history_timestep": True, - "is_amplify_history": False, - } - - def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + def dummy_input(self): batch_size = 1 num_channels = 4 num_frames = 2 @@ -113,6 +73,40 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: "latents_history_long": latents_history_long, } + @property + def input_shape(self) -> tuple[int, ...]: + return (4, 2, 16, 16) 
+ + @property + def output_shape(self) -> tuple[int, ...]: + return (4, 2, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, + "in_channels": 4, + "out_channels": 4, + "text_dim": 16, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "rope_dim": (4, 4, 4), + "has_multi_term_memory_patch": True, + "guidance_cross_attn": True, + "zero_history_timestep": True, + "is_amplify_history": False, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HeliosTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + class TestHeliosTransformer3D(HeliosTransformer3DTesterConfig, ModelTesterMixin): """Core model tests for Helios Transformer 3D.""" @@ -142,53 +136,3 @@ class TestHeliosTransformer3DAttention(HeliosTransformer3DTesterConfig, Attentio class TestHeliosTransformer3DCompile(HeliosTransformer3DTesterConfig, TorchCompileTesterMixin): """Torch compile tests for Helios Transformer 3D.""" - - -class TestHeliosTransformer3DBitsAndBytes(HeliosTransformer3DTesterConfig, BitsAndBytesTesterMixin): - """BitsAndBytes quantization tests for Helios Transformer 3D.""" - - @property - def torch_dtype(self): - return torch.float16 - - -class TestHeliosTransformer3DTorchAo(HeliosTransformer3DTesterConfig, TorchAoTesterMixin): - """TorchAO quantization tests for Helios Transformer 3D.""" - - @property - def torch_dtype(self): - return torch.bfloat16 - - -# class TestHeliosTransformer3DGGUF(HeliosTransformer3DTesterConfig, GGUFTesterMixin): -# """GGUF quantization tests for Helios Transformer 3D.""" - -# @property -# def gguf_filename(self): -# return "" - -# @property -# def torch_dtype(self): -# return torch.bfloat16 - -# def _create_quantized_model(self, config_kwargs=None, 
**extra_kwargs): -# return super()._create_quantized_model( -# config_kwargs, config="BestWishYsh/Helios-Base", subfolder="transformer", **extra_kwargs -# ) - - -# class TestHeliosTransformer3DGGUFCompile(HeliosTransformer3DTesterConfig, GGUFCompileTesterMixin): -# """GGUF + compile tests for Helios Transformer 3D.""" - -# @property -# def gguf_filename(self): -# return "" - -# @property -# def torch_dtype(self): -# return torch.bfloat16 - -# def _create_quantized_model(self, config_kwargs=None, **extra_kwargs): -# return super()._create_quantized_model( -# config_kwargs, config="BestWishYsh/Helios-Base", subfolder="transformer", **extra_kwargs -# ) From 604247c1fdd3fd13eca8262ef61fed23e6723e2b Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Tue, 3 Mar 2026 02:03:30 +0000 Subject: [PATCH 092/107] fallbackt to BaseModelTesterConfig --- .../test_models_transformer_helios.py | 85 ++++++++++--------- 1 file changed, 45 insertions(+), 40 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_helios.py b/tests/models/transformers/test_models_transformer_helios.py index 48ed52c75f8c..fa104e68b7d3 100644 --- a/tests/models/transformers/test_models_transformer_helios.py +++ b/tests/models/transformers/test_models_transformer_helios.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest import pytest import torch @@ -22,6 +21,7 @@ from ...testing_utils import enable_full_determinism, torch_device from ..testing_utils import ( AttentionTesterMixin, + BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TorchCompileTesterMixin, @@ -32,13 +32,52 @@ enable_full_determinism() -class HeliosTransformer3DTesterConfig(ModelTesterMixin, unittest.TestCase): - model_class = HeliosTransformer3DModel - main_input_name = "hidden_states" - uses_custom_attn_processor = True +class HeliosTransformer3DTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return HeliosTransformer3DModel + + @property + def pretrained_model_name_or_path(self): + return "hf-internal-testing/tiny-helios-base-transformer" + + @property + def output_shape(self) -> tuple[int, ...]: + return (4, 2, 16, 16) + + @property + def input_shape(self) -> tuple[int, ...]: + return (4, 2, 16, 16) + + @property + def main_input_name(self) -> str: + return "hidden_states" @property - def dummy_input(self): + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]: + return { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, + "in_channels": 4, + "out_channels": 4, + "text_dim": 16, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "rope_dim": (4, 4, 4), + "has_multi_term_memory_patch": True, + "guidance_cross_attn": True, + "zero_history_timestep": True, + "is_amplify_history": False, + } + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: batch_size = 1 num_channels = 4 num_frames = 2 @@ -73,40 +112,6 @@ def dummy_input(self): "latents_history_long": latents_history_long, } - @property - def input_shape(self) -> tuple[int, ...]: - return (4, 2, 16, 16) - - @property - def output_shape(self) -> tuple[int, ...]: - return (4, 2, 16, 16) - - def 
prepare_init_args_and_inputs_for_common(self): - init_dict = { - "patch_size": (1, 2, 2), - "num_attention_heads": 2, - "attention_head_dim": 12, - "in_channels": 4, - "out_channels": 4, - "text_dim": 16, - "freq_dim": 256, - "ffn_dim": 32, - "num_layers": 2, - "cross_attn_norm": True, - "qk_norm": "rms_norm_across_heads", - "rope_dim": (4, 4, 4), - "has_multi_term_memory_patch": True, - "guidance_cross_attn": True, - "zero_history_timestep": True, - "is_amplify_history": False, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_gradient_checkpointing_is_applied(self): - expected_set = {"HeliosTransformer3DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - class TestHeliosTransformer3D(HeliosTransformer3DTesterConfig, ModelTesterMixin): """Core model tests for Helios Transformer 3D.""" From 85738439344ac1d7b2eda2e3937377d613f611a0 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Tue, 3 Mar 2026 02:43:42 +0000 Subject: [PATCH 093/107] remove _maybe_expand_t2v_lora_for_i2v --- src/diffusers/loaders/lora_pipeline.py | 172 +++++++++---------------- 1 file changed, 62 insertions(+), 110 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 83865ba7c701..0e5c2c5da9a9 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -4321,9 +4321,9 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) -class HeliosLoraLoaderMixin(LoraBaseMixin): +class SkyReelsV2LoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`HeliosTransformer3DModel`]. Specific to [`HeliosPipeline`] and [`HeliosPyramidPipeline`]. + Load LoRA layers into [`SkyReelsV2Transformer3DModel`]. 
""" _lora_loadable_modules = ["transformer"] @@ -4331,6 +4331,7 @@ class HeliosLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4388,6 +4389,7 @@ def lora_state_dict( return out @classmethod + # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin._maybe_expand_t2v_lora_for_i2v def _maybe_expand_t2v_lora_for_i2v( cls, transformer: torch.nn.Module, @@ -4435,6 +4437,7 @@ def _maybe_expand_t2v_lora_for_i2v( return state_dict + # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4501,7 +4504,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel def load_lora_into_transformer( cls, state_dict, @@ -4595,9 +4598,9 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) -class SkyReelsV2LoraLoaderMixin(LoraBaseMixin): +class CogView4LoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`SkyReelsV2Transformer3DModel`]. + Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`]. 
""" _lora_loadable_modules = ["transformer"] @@ -4605,7 +4608,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4648,10 +4651,6 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) - if any(k.startswith("diffusion_model.") for k in state_dict): - state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) - elif any(k.startswith("lora_unet_") for k in state_dict): - state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -4662,56 +4661,7 @@ def lora_state_dict( out = (state_dict, metadata) if return_lora_metadata else state_dict return out - @classmethod - # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin._maybe_expand_t2v_lora_for_i2v - def _maybe_expand_t2v_lora_for_i2v( - cls, - transformer: torch.nn.Module, - state_dict, - ): - if transformer.config.image_dim is None: - return state_dict - - target_device = transformer.device - - if any(k.startswith("transformer.blocks.") for k in state_dict): - num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k}) - is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict) - has_bias = any(".lora_B.bias" in k for k in state_dict) - - if is_i2v_lora: - return state_dict - - for i in range(num_blocks): - for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): - # These keys should exist if the block `i` was part of the T2V LoRA. 
- ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight" - ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight" - - if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict: - continue - - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like( - state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device - ) - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like( - state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device - ) - - # If the original LoRA had biases (indicated by has_bias) - # AND the specific reference bias key exists for this block. - - ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias" - if has_bias and ref_key_lora_B_bias in state_dict: - ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias] - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like( - ref_lora_B_bias_tensor, - device=target_device, - ) - - return state_dict - - # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.load_lora_weights + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4738,47 +4688,23 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
kwargs["return_lora_metadata"] = True state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers - state_dict = self._maybe_expand_t2v_lora_for_i2v( - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - state_dict=state_dict, - ) + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") - load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) - if load_into_transformer_2: - if not hasattr(self, "transformer_2"): - raise AttributeError( - f"'{type(self).__name__}' object has no attribute transformer_2" - "Note that Wan2.1 models do not have a transformer_2 component." - "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True." 
- ) - self.load_lora_into_transformer( - state_dict, - transformer=self.transformer_2, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - else: - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) - if not hasattr(self, "transformer") - else self.transformer, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel def load_lora_into_transformer( cls, state_dict, @@ -4872,9 +4798,9 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) -class CogView4LoraLoaderMixin(LoraBaseMixin): +class HeliosLoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`]. + Load LoRA layers into [`HeliosTransformer3DModel`]. Specific to [`HeliosPipeline`] and [`HeliosPyramidPipeline`]. 
""" _lora_loadable_modules = ["transformer"] @@ -4882,7 +4808,6 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4925,6 +4850,10 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + if any(k.startswith("diffusion_model.") for k in state_dict): + state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) + elif any(k.startswith("lora_unet_") for k in state_dict): + state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -4935,7 +4864,6 @@ def lora_state_dict( out = (state_dict, metadata) if return_lora_metadata else state_dict return out - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4962,23 +4890,47 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. kwargs["return_lora_metadata"] = True state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - + # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers + state_dict = self._maybe_expand_t2v_lora_for_i2v( + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + state_dict=state_dict, + ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `'lora'` substring.") - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) + load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) + if load_into_transformer_2: + if not hasattr(self, "transformer_2"): + raise AttributeError( + f"'{type(self).__name__}' object has no attribute transformer_2" + "Note that Wan2.1 models do not have a transformer_2 component." + "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True." + ) + self.load_lora_into_transformer( + state_dict, + transformer=self.transformer_2, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + else: + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) + if not hasattr(self, "transformer") + else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel def load_lora_into_transformer( cls, state_dict, From 7754a664fd690462d98064449e02b51f13f137bb Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Tue, 3 Mar 2026 02:58:49 +0000 Subject: [PATCH 094/107] fix HeliosLoraLoaderMixin --- src/diffusers/loaders/lora_pipeline.py | 329 ++++++++++++------------- 1 file changed, 152 insertions(+), 177 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py 
b/src/diffusers/loaders/lora_pipeline.py index 0e5c2c5da9a9..5d10f596f2e6 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3440,9 +3440,9 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) -class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): +class HeliosLoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`]. + Load LoRA layers into [`HeliosTransformer3DModel`]. Specific to [`HeliosPipeline`] and [`HeliosPyramidPipeline`]. """ _lora_loadable_modules = ["transformer"] @@ -3492,6 +3492,10 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + if any(k.startswith("diffusion_model.") for k in state_dict): + state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) + elif any(k.startswith("lora_unet_") for k in state_dict): + state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -3499,14 +3503,9 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict) - if is_original_hunyuan_video: - state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) - out = (state_dict, metadata) if return_lora_metadata else state_dict return out - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -3533,7 +3532,6 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
kwargs["return_lora_metadata"] = True state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") @@ -3549,7 +3547,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel def load_lora_into_transformer( cls, state_dict, @@ -3643,9 +3641,9 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) -class Lumina2LoraLoaderMixin(LoraBaseMixin): +class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`]. + Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`]. """ _lora_loadable_modules = ["transformer"] @@ -3702,10 +3700,9 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - # conversion. 
- non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict) - if non_diffusers: - state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict) + is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict) + if is_original_hunyuan_video: + state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) out = (state_dict, metadata) if return_lora_metadata else state_dict return out @@ -3753,7 +3750,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel def load_lora_into_transformer( cls, state_dict, @@ -3819,7 +3816,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora def fuse_lora( self, components: list[str] = ["transformer"], @@ -3839,7 +3836,7 @@ def fuse_lora( **kwargs, ) - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): r""" See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. @@ -3847,9 +3844,9 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) -class KandinskyLoraLoaderMixin(LoraBaseMixin): +class Lumina2LoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`Kandinsky5Transformer3DModel`], + Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`]. 
""" _lora_loadable_modules = ["transformer"] @@ -3857,7 +3854,6 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -3907,6 +3903,11 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + # conversion. + non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict) + if non_diffusers: + state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict) + out = (state_dict, metadata) if return_lora_metadata else state_dict return out @@ -3953,7 +3954,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel def load_lora_into_transformer( cls, state_dict, @@ -4019,7 +4020,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora def fuse_lora( self, components: list[str] = ["transformer"], @@ -4039,7 +4040,7 @@ def fuse_lora( **kwargs, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): r""" See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. 
@@ -4047,16 +4048,17 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) -class WanLoraLoaderMixin(LoraBaseMixin): +class KandinskyLoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`]. + Load LoRA layers into [`Kandinsky5Transformer3DModel`], """ - _lora_loadable_modules = ["transformer", "transformer_2"] + _lora_loadable_modules = ["transformer"] transformer_name = TRANSFORMER_NAME @classmethod @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4099,10 +4101,6 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) - if any(k.startswith("diffusion_model.") for k in state_dict): - state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) - elif any(k.startswith("lora_unet_") for k in state_dict): - state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -4113,54 +4111,7 @@ def lora_state_dict( out = (state_dict, metadata) if return_lora_metadata else state_dict return out - @classmethod - def _maybe_expand_t2v_lora_for_i2v( - cls, - transformer: torch.nn.Module, - state_dict, - ): - if transformer.config.image_dim is None: - return state_dict - - target_device = transformer.device - - if any(k.startswith("transformer.blocks.") for k in state_dict): - num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." 
in k}) - is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict) - has_bias = any(".lora_B.bias" in k for k in state_dict) - - if is_i2v_lora: - return state_dict - - for i in range(num_blocks): - for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): - # These keys should exist if the block `i` was part of the T2V LoRA. - ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight" - ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight" - - if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict: - continue - - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like( - state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device - ) - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like( - state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device - ) - - # If the original LoRA had biases (indicated by has_bias) - # AND the specific reference bias key exists for this block. - - ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias" - if has_bias and ref_key_lora_B_bias in state_dict: - ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias] - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like( - ref_lora_B_bias_tensor, - device=target_device, - ) - - return state_dict - + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4187,47 +4138,23 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
kwargs["return_lora_metadata"] = True state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers - state_dict = self._maybe_expand_t2v_lora_for_i2v( - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - state_dict=state_dict, - ) + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") - load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) - if load_into_transformer_2: - if not hasattr(self, "transformer_2"): - raise AttributeError( - f"'{type(self).__name__}' object has no attribute transformer_2" - "Note that Wan2.1 models do not have a transformer_2 component." - "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True." 
- ) - self.load_lora_into_transformer( - state_dict, - transformer=self.transformer_2, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - else: - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) - if not hasattr(self, "transformer") - else self.transformer, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer def load_lora_into_transformer( cls, state_dict, @@ -4321,17 +4248,16 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) -class SkyReelsV2LoraLoaderMixin(LoraBaseMixin): +class WanLoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`SkyReelsV2Transformer3DModel`]. + Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`]. 
""" - _lora_loadable_modules = ["transformer"] + _lora_loadable_modules = ["transformer", "transformer_2"] transformer_name = TRANSFORMER_NAME @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4389,7 +4315,6 @@ def lora_state_dict( return out @classmethod - # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin._maybe_expand_t2v_lora_for_i2v def _maybe_expand_t2v_lora_for_i2v( cls, transformer: torch.nn.Module, @@ -4437,7 +4362,6 @@ def _maybe_expand_t2v_lora_for_i2v( return state_dict - # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4504,7 +4428,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel def load_lora_into_transformer( cls, state_dict, @@ -4598,9 +4522,9 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) -class CogView4LoraLoaderMixin(LoraBaseMixin): +class SkyReelsV2LoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`]. + Load LoRA layers into [`SkyReelsV2Transformer3DModel`]. 
""" _lora_loadable_modules = ["transformer"] @@ -4608,7 +4532,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4651,6 +4575,10 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + if any(k.startswith("diffusion_model.") for k in state_dict): + state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) + elif any(k.startswith("lora_unet_") for k in state_dict): + state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -4661,7 +4589,56 @@ def lora_state_dict( out = (state_dict, metadata) if return_lora_metadata else state_dict return out - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + @classmethod + # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin._maybe_expand_t2v_lora_for_i2v + def _maybe_expand_t2v_lora_for_i2v( + cls, + transformer: torch.nn.Module, + state_dict, + ): + if transformer.config.image_dim is None: + return state_dict + + target_device = transformer.device + + if any(k.startswith("transformer.blocks.") for k in state_dict): + num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k}) + is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict) + has_bias = any(".lora_B.bias" in k for k in state_dict) + + if is_i2v_lora: + return state_dict + + for i in range(num_blocks): + for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): + # These keys should exist if the block `i` was part of the T2V LoRA. 
+ ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight" + ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight" + + if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict: + continue + + state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like( + state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device + ) + state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like( + state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device + ) + + # If the original LoRA had biases (indicated by has_bias) + # AND the specific reference bias key exists for this block. + + ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias" + if has_bias and ref_key_lora_B_bias in state_dict: + ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias] + state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like( + ref_lora_B_bias_tensor, + device=target_device, + ) + + return state_dict + + # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4688,23 +4665,47 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. kwargs["return_lora_metadata"] = True state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - + # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers + state_dict = self._maybe_expand_t2v_lora_for_i2v( + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + state_dict=state_dict, + ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `'lora'` substring.") - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) + load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) + if load_into_transformer_2: + if not hasattr(self, "transformer_2"): + raise AttributeError( + f"'{type(self).__name__}' object has no attribute transformer_2" + "Note that Wan2.1 models do not have a transformer_2 component." + "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True." + ) + self.load_lora_into_transformer( + state_dict, + transformer=self.transformer_2, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + else: + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) + if not hasattr(self, "transformer") + else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel def load_lora_into_transformer( cls, state_dict, @@ -4798,9 +4799,9 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) -class HeliosLoraLoaderMixin(LoraBaseMixin): +class CogView4LoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`HeliosTransformer3DModel`]. Specific to [`HeliosPipeline`] and [`HeliosPyramidPipeline`]. 
+ Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`]. """ _lora_loadable_modules = ["transformer"] @@ -4808,6 +4809,7 @@ class HeliosLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4850,10 +4852,6 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) - if any(k.startswith("diffusion_model.") for k in state_dict): - state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) - elif any(k.startswith("lora_unet_") for k in state_dict): - state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -4864,6 +4862,7 @@ def lora_state_dict( out = (state_dict, metadata) if return_lora_metadata else state_dict return out + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -4890,47 +4889,23 @@ def load_lora_weights( # First, ensure that the checkpoint is a compatible one and can be successfully loaded. kwargs["return_lora_metadata"] = True state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers - state_dict = self._maybe_expand_t2v_lora_for_i2v( - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - state_dict=state_dict, - ) + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `'lora'` substring.") - load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) - if load_into_transformer_2: - if not hasattr(self, "transformer_2"): - raise AttributeError( - f"'{type(self).__name__}' object has no attribute transformer_2" - "Note that Wan2.1 models do not have a transformer_2 component." - "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True." - ) - self.load_lora_into_transformer( - state_dict, - transformer=self.transformer_2, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - else: - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) - if not hasattr(self, "transformer") - else self.transformer, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel def load_lora_into_transformer( cls, state_dict, From c3c7c9759fe872631626a6e4ffc6670b1147410d Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Tue, 3 Mar 2026 03:08:14 +0000 Subject: [PATCH 095/107] update docs --- docs/source/en/api/pipelines/helios.md | 13 +++++++------ docs/source/en/using-diffusers/helios.md | 2 +- docs/source/zh/using-diffusers/helios.md | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff 
--git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index 24ffae6e3939..0c711b56642e 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -24,7 +24,7 @@ [Helios: 14B Real-Time Long Video Generation Model can be Cheaper, Faster but Keep Stronger than 1.3B ones](https://huggingface.co/papers/) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Xinwei Huang, Xiao Yang, Li Yuan. -* We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality. We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. 
Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios). +* We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality. We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios-Page). 
The following Helios models are supported in Diffusers: @@ -366,11 +366,12 @@ misshapen limbs, fused fingers, still picture, messy background, three legs, man # For Text-to-Video prompt = """ -A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various -elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. -The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but -emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and -exploration. Medium shot focusing on the train window and the rushing scenery beyond. +A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue +and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with +a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, +allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades +of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and +the vivid colors of its surroundings. A close-up shot with dynamic movement. """ output = pipeline( diff --git a/docs/source/en/using-diffusers/helios.md b/docs/source/en/using-diffusers/helios.md index d95344003ef8..6de2e01a1c36 100644 --- a/docs/source/en/using-diffusers/helios.md +++ b/docs/source/en/using-diffusers/helios.md @@ -129,5 +129,5 @@ pipe.to("cuda") ## Resources Learn more about Helios with the following resources. -- A [video](https://www.youtube.com/watch?v=) demonstrating Helios's main features. 
+- Watch [video1](https://www.youtube.com/watch?v=vd_AgHtOUFQ) and [video2](https://www.youtube.com/watch?v=1GeIU2Dn7UY) for a demonstration of Helios's key features. - The research paper, [Helios: 14B Real-Time Long Video Generation Model can be Cheaper, Faster but Keep Stronger than 1.3B ones](https://huggingface.co/papers/) for more details. diff --git a/docs/source/zh/using-diffusers/helios.md b/docs/source/zh/using-diffusers/helios.md index 20a43b40da14..11a64b54706b 100644 --- a/docs/source/zh/using-diffusers/helios.md +++ b/docs/source/zh/using-diffusers/helios.md @@ -130,5 +130,5 @@ pipe.to("cuda") 通过以下资源了解有关 Helios 的更多信息: -- 一段 [视频](https://www.youtube.com/watch?v=) 演示了 Helios 的主要功能; +- [视频1](https://www.youtube.com/watch?v=vd_AgHtOUFQ)和[视频2](https://www.youtube.com/watch?v=1GeIU2Dn7UY)演示了 Helios 的主要功能; - 有关更多详细信息,请参阅研究论文 [Helios: 14B Real-Time Long Video Generation Model can be Cheaper, Faster but Keep Stronger than 1.3B ones](https://huggingface.co/papers/)。 From 5d9ed985175d9e8224298a54ff3e3d4c0b89e654 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Tue, 3 Mar 2026 06:27:05 +0000 Subject: [PATCH 096/107] use randn_tensor for test --- .../models/transformers/test_models_transformer_helios.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/models/transformers/test_models_transformer_helios.py b/tests/models/transformers/test_models_transformer_helios.py index fa104e68b7d3..563d5d9f9281 100644 --- a/tests/models/transformers/test_models_transformer_helios.py +++ b/tests/models/transformers/test_models_transformer_helios.py @@ -17,6 +17,7 @@ import torch from diffusers import HeliosTransformer3DModel +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device from ..testing_utils import ( @@ -86,7 +87,11 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: text_encoder_embedding_dim = 16 sequence_length = 12 - hidden_states = torch.randn((batch_size, 
num_channels, num_frames, height, width)).to(torch_device) + hidden_states = randn_tensor( + (batch_size, num_channels, num_frames, height, width), + generator=self.generator, + device=torch_device, + ) timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) indices_hidden_states = torch.ones((batch_size, num_frames)).to(torch_device) From 8a63b0b76a4d26ae0c4ed03adc3c993d13276c1d Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Tue, 3 Mar 2026 07:22:10 +0000 Subject: [PATCH 097/107] fix doc typo --- docs/source/en/using-diffusers/helios.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/using-diffusers/helios.md b/docs/source/en/using-diffusers/helios.md index 6de2e01a1c36..c98c31abedbb 100644 --- a/docs/source/en/using-diffusers/helios.md +++ b/docs/source/en/using-diffusers/helios.md @@ -13,8 +13,8 @@ specific language governing permissions and limitations under the License. [Helios](https://github.com/PKU-YuanGroup/Helios) is the first 14B video generation model that runs at 19.5 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching the quality of a strong baseline, natively integrating T2V, I2V, and V2V tasks within a unified architecture. The main features of Helios are: -- Without commonly used anti-drifting strategies (\eg, self-forcing, error-banks, keyframe sampling, or inverted sampling), \ours generates minute-scale videos with high quality and strong coherence. -- Without standard acceleration techniques (\eg, KV-cache, causal masking, sparse/linear attention, TinyVAE, progressive noise schedules, hidden-state caching, or quantization), \ours achieves 19.5 FPS in end-to-end inference for a 14B video generation model on a single H100 GPU. 
+- Without commonly used anti-drifting strategies (eg, self-forcing, error-banks, keyframe sampling, or inverted sampling), Helios generates minute-scale videos with high quality and strong coherence. +- Without standard acceleration techniques (eg, KV-cache, causal masking, sparse/linear attention, TinyVAE, progressive noise schedules, hidden-state caching, or quantization), Helios achieves 19.5 FPS in end-to-end inference for a 14B video generation model on a single H100 GPU. - Introducing optimizations that improve both training and inference throughput while reducing memory consumption. These changes enable training a 14B video generation model without parallelism or sharding infrastructure, with batch sizes comparable to image models. This guide will walk you through using Helios for use cases. From 6098ebd36f7884301a43346e724d14022675efd9 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Tue, 3 Mar 2026 07:36:54 +0000 Subject: [PATCH 098/107] optimize code --- .../pipelines/helios/pipeline_helios.py | 20 +-- .../helios/pipeline_helios_pyramid.py | 42 +++-- tests/pipelines/helios/test_helios.py | 155 ++++++++++++++++++ 3 files changed, 191 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index f0fbd0b90b21..0902ec425e69 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -552,15 +552,6 @@ def __call__( history_sizes = sorted(history_sizes, reverse=True) # From big to small - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(self.vae.device, self.vae.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - self.vae.device, self.vae.dtype - ) - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = 
callback_on_step_end.tensor_inputs @@ -587,6 +578,15 @@ def __call__( device = self._execution_device vae_dtype = self.vae.dtype + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -778,7 +778,7 @@ def __call__( latents_history_1x.shape[-2], latents_history_1x.shape[-1], ), - device=latents_history_1x.device, + device=device, dtype=latents_history_1x.dtype, ) else: diff --git a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py index 4773f1befec5..40c1d65825ff 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py +++ b/src/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -438,16 +438,26 @@ def prepare_video_latents( latents = torch.cat(latents_chunks, dim=2) return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype) - def sample_block_noise(self, batch_size, channel, num_frames, height, width, patch_size=(1, 2, 2)): + def sample_block_noise( + self, + batch_size, + channel, + num_frames, + height, + width, + patch_size: tuple[int, ...] 
= (1, 2, 2), + device: torch.device | None = None, + ): gamma = self.scheduler.config.gamma _, ph, pw = patch_size block_size = ph * pw - cov = torch.eye(block_size) * (1 + gamma) - torch.ones(block_size, block_size) * gamma - cov += torch.eye(block_size) * 1e-6 - dist = torch.distributions.MultivariateNormal( - torch.zeros(block_size, device=cov.device), covariance_matrix=cov + cov = ( + torch.eye(block_size, device=device) * (1 + gamma) + - torch.ones(block_size, block_size, device=device) * gamma ) + cov += torch.eye(block_size, device=device) * 1e-6 + dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=device), covariance_matrix=cov) block_number = batch_size * channel * num_frames * (height // ph) * (width // pw) noise = dist.sample((block_number,)) # [block number, block_size] @@ -593,15 +603,6 @@ def __call__( history_sizes = sorted(history_sizes, reverse=True) # From big to small pyramid_num_stages = len(pyramid_num_inference_steps_list) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(self.vae.device, self.vae.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - self.vae.device, self.vae.dtype - ) - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -629,6 +630,15 @@ def __call__( device = self._execution_device vae_dtype = self.vae.dtype + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) + # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -804,7 +814,7 @@ def __call__( latents_history_1x.shape[-2], latents_history_1x.shape[-1], ), - device=latents_history_1x.device, + device=device, dtype=latents_history_1x.dtype, ) else: @@ -906,7 +916,7 @@ def __call__( batch_size, channel, num_frames, pyramid_height, pyramid_width = latents.shape noise = self.sample_block_noise( - batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size + batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size, device ) noise = noise.to(device=device, dtype=transformer_dtype) latents = alpha * latents + beta * noise # To fix the block artifact diff --git a/tests/pipelines/helios/test_helios.py b/tests/pipelines/helios/test_helios.py index 86e1ecbf70be..63bc3bad4dd2 100644 --- a/tests/pipelines/helios/test_helios.py +++ b/tests/pipelines/helios/test_helios.py @@ -11,3 +11,158 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +import gc +import unittest + +import torch +from transformers import AutoConfig, AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLWan, HeliosPipeline, HeliosScheduler, HeliosTransformer3DModel + +from ...testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = HeliosPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = HeliosScheduler(stage_range=[0, 1], stages=1, use_dynamic_shifting=True) + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5") + text_encoder = T5EncoderModel(config) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = HeliosTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_dim=(4, 4, 4), + has_multi_term_memory_patch=True, + 
guidance_cross_attn=True, + zero_history_timestep=True, + is_amplify_history=False, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (33, 3, 16, 16)) + + # fmt: off + expected_slice = torch.tensor([0.4529, 0.4527, 0.4499, 0.4542, 0.4528, 0.4524, 0.4531, 0.4534, 0.5328, + 0.5340, 0.5012, 0.5135, 0.5322, 0.5203, 0.5144, 0.5101]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("Optional components not applicable for Helios") + def test_save_load_optional_components(self): + pass + + +@slow +@require_torch_accelerator +class HeliosPipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." 
+ + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + @unittest.skip("TODO: test needs to be implemented") + def test_helios(self): + pass From fa927e331a916b991c0c11a207fa265ab2639789 Mon Sep 17 00:00:00 2001 From: yuanshenghai Date: Wed, 4 Mar 2026 02:00:37 +0000 Subject: [PATCH 099/107] mark torch.compile xfail --- tests/models/transformers/test_models_transformer_helios.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/transformers/test_models_transformer_helios.py b/tests/models/transformers/test_models_transformer_helios.py index 563d5d9f9281..c6d06b4f0beb 100644 --- a/tests/models/transformers/test_models_transformer_helios.py +++ b/tests/models/transformers/test_models_transformer_helios.py @@ -146,3 +146,7 @@ class TestHeliosTransformer3DAttention(HeliosTransformer3DTesterConfig, Attentio class TestHeliosTransformer3DCompile(HeliosTransformer3DTesterConfig, TorchCompileTesterMixin): """Torch compile tests for Helios Transformer 3D.""" + + @pytest.mark.xfail(reason="Helios DiT does not compile when deterministic algorithms are used due to https://github.com/pytorch/pytorch/issues/170079") + def test_torch_compile_recompilation_and_graph_break(self): + super().test_torch_compile_recompilation_and_graph_break() From d2118ef7abd6bf4cfbf1184c98a03dbd4e6be079 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Wed, 4 Mar 2026 03:44:34 +0000 Subject: [PATCH 100/107] change paper name --- docs/source/en/api/models/helios_transformer3d.md | 2 +- docs/source/en/api/pipelines/helios.md | 2 +- docs/source/en/using-diffusers/helios.md | 2 +- docs/source/zh/using-diffusers/helios.md | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/models/helios_transformer3d.md b/docs/source/en/api/models/helios_transformer3d.md index 75f1c536aa3e..5aa2826c32ec 100644 --- 
a/docs/source/en/api/models/helios_transformer3d.md +++ b/docs/source/en/api/models/helios_transformer3d.md @@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. --> # HeliosTransformer3DModel -A 14B Real-Time Autogressive Diffusion Transformer model (support T2V, I2V and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: 14B Real-Time Long Video Generation Model can be Cheaper, Faster but Keep Stronger than 1.3B ones](https://huggingface.co/papers/) by Peking University & ByteDance & etc. +A 14B Real-Time Autoregressive Diffusion Transformer model (support T2V, I2V and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) by Peking University & ByteDance & etc. The model can be loaded with the following code snippet. diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index 0c711b56642e..d00a24a8d105 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -22,7 +22,7 @@ # Helios -[Helios: 14B Real-Time Long Video Generation Model can be Cheaper, Faster but Keep Stronger than 1.3B ones](https://huggingface.co/papers/) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Xinwei Huang, Xiao Yang, Li Yuan. +[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Xinwei Huang, Xiao Yang, Li Yuan. * We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality.
We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios-Page). diff --git a/docs/source/en/using-diffusers/helios.md b/docs/source/en/using-diffusers/helios.md index c98c31abedbb..8106f1c568f8 100644 --- a/docs/source/en/using-diffusers/helios.md +++ b/docs/source/en/using-diffusers/helios.md @@ -130,4 +130,4 @@ pipe.to("cuda") Learn more about Helios with the following resources. - Watch [video1](https://www.youtube.com/watch?v=vd_AgHtOUFQ) and [video2](https://www.youtube.com/watch?v=1GeIU2Dn7UY) for a demonstration of Helios's key features. 
-- The research paper, [Helios: 14B Real-Time Long Video Generation Model can be Cheaper, Faster but Keep Stronger than 1.3B ones](https://huggingface.co/papers/) for more details. +- The research paper, [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) for more details. diff --git a/docs/source/zh/using-diffusers/helios.md b/docs/source/zh/using-diffusers/helios.md index 11a64b54706b..5c4faed2ca2a 100644 --- a/docs/source/zh/using-diffusers/helios.md +++ b/docs/source/zh/using-diffusers/helios.md @@ -131,4 +131,4 @@ pipe.to("cuda") 通过以下资源了解有关 Helios 的更多信息: - [视频1](https://www.youtube.com/watch?v=vd_AgHtOUFQ)和[视频2](https://www.youtube.com/watch?v=1GeIU2Dn7UY)演示了 Helios 的主要功能; -- 有关更多详细信息,请参阅研究论文 [Helios: 14B Real-Time Long Video Generation Model can be Cheaper, Faster but Keep Stronger than 1.3B ones](https://huggingface.co/papers/)。 +- 有关更多详细信息,请参阅研究论文 [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/)。 From 95a04089ff60bcb6fae0a606b97c4ccdf0756587 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 4 Mar 2026 09:03:06 +0100 Subject: [PATCH 101/107] Make get_dummy_inputs deterministic using self.generator --- .../test_models_transformer_helios.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_helios.py b/tests/models/transformers/test_models_transformer_helios.py index c6d06b4f0beb..a6cf63d137c8 100644 --- a/tests/models/transformers/test_models_transformer_helios.py +++ b/tests/models/transformers/test_models_transformer_helios.py @@ -92,16 +92,30 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: generator=self.generator, device=torch_device, ) - timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,), 
generator=self.generator).to(torch_device) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, text_encoder_embedding_dim), + generator=self.generator, + device=torch_device, + ) indices_hidden_states = torch.ones((batch_size, num_frames)).to(torch_device) indices_latents_history_short = torch.ones((batch_size, num_frames - 1)).to(torch_device) indices_latents_history_mid = torch.ones((batch_size, num_frames - 1)).to(torch_device) indices_latents_history_long = torch.ones((batch_size, (num_frames - 1) * 4)).to(torch_device) - latents_history_short = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device) - latents_history_mid = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device) - latents_history_long = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to( - torch_device + latents_history_short = randn_tensor( + (batch_size, num_channels, num_frames - 1, height, width), + generator=self.generator, + device=torch_device, + ) + latents_history_mid = randn_tensor( + (batch_size, num_channels, num_frames - 1, height, width), + generator=self.generator, + device=torch_device, + ) + latents_history_long = randn_tensor( + (batch_size, num_channels, (num_frames - 1) * 4, height, width), + generator=self.generator, + device=torch_device, ) return { From 339afe57f01861f92af895f8526af500dfb6c324 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 4 Mar 2026 09:09:40 +0100 Subject: [PATCH 102/107] Set less strict threshold for test_save_load_float16 test for Helios pipeline --- tests/pipelines/helios/test_helios.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pipelines/helios/test_helios.py b/tests/pipelines/helios/test_helios.py index 63bc3bad4dd2..b8ee99085036 100644 --- a/tests/pipelines/helios/test_helios.py +++ b/tests/pipelines/helios/test_helios.py @@ -139,6 +139,10 @@ def test_inference(self): generated_slice = 
torch.cat([generated_slice[:8], generated_slice[-8:]]) self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) + # Override to set a more lenient max diff threshold. + def test_save_load_float16(self): + super().test_save_load_float16(expected_max_diff=0.03) + @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): pass From 6ead7ad5e10ce872c703164ba06657d661099b94 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 4 Mar 2026 09:10:25 +0100 Subject: [PATCH 103/107] make style and make quality --- tests/models/transformers/test_models_transformer_helios.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_helios.py b/tests/models/transformers/test_models_transformer_helios.py index a6cf63d137c8..c365c258e596 100644 --- a/tests/models/transformers/test_models_transformer_helios.py +++ b/tests/models/transformers/test_models_transformer_helios.py @@ -160,7 +160,9 @@ class TestHeliosTransformer3DAttention(HeliosTransformer3DTesterConfig, Attentio class TestHeliosTransformer3DCompile(HeliosTransformer3DTesterConfig, TorchCompileTesterMixin): """Torch compile tests for Helios Transformer 3D.""" - - @pytest.mark.xfail(reason="Helios DiT does not compile when deterministic algorithms are used due to https://github.com/pytorch/pytorch/issues/170079") + + @pytest.mark.xfail( + reason="Helios DiT does not compile when deterministic algorithms are used due to https://github.com/pytorch/pytorch/issues/170079" + ) def test_torch_compile_recompilation_and_graph_break(self): super().test_torch_compile_recompilation_and_graph_break() From e53374d8b9a0fe898a318276f87b6ede50d8ef1f Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Wed, 4 Mar 2026 09:30:23 +0000 Subject: [PATCH 104/107] Preparation for merging --- 0_temp_helios_test/helios-base_i2v.sh | 14 - 0_temp_helios_test/helios-base_t2v.sh | 13 - 0_temp_helios_test/helios-base_v2v.sh | 14 - 
0_temp_helios_test/helios-distilled_i2v.sh | 17 -- 0_temp_helios_test/helios-distilled_t2v.sh | 16 -- 0_temp_helios_test/helios-distilled_v2v.sh | 17 -- 0_temp_helios_test/helios-mid_i2v.sh | 16 -- 0_temp_helios_test/helios-mid_t2v.sh | 15 - 0_temp_helios_test/helios-mid_v2v.sh | 16 -- 0_temp_helios_test/infer_helios.py | 312 --------------------- 0_temp_helios_test/requirements.txt | 36 --- docs/source/en/api/pipelines/helios.md | 35 +-- 12 files changed, 19 insertions(+), 502 deletions(-) delete mode 100644 0_temp_helios_test/helios-base_i2v.sh delete mode 100644 0_temp_helios_test/helios-base_t2v.sh delete mode 100644 0_temp_helios_test/helios-base_v2v.sh delete mode 100644 0_temp_helios_test/helios-distilled_i2v.sh delete mode 100644 0_temp_helios_test/helios-distilled_t2v.sh delete mode 100644 0_temp_helios_test/helios-distilled_v2v.sh delete mode 100644 0_temp_helios_test/helios-mid_i2v.sh delete mode 100644 0_temp_helios_test/helios-mid_t2v.sh delete mode 100644 0_temp_helios_test/helios-mid_v2v.sh delete mode 100644 0_temp_helios_test/infer_helios.py delete mode 100644 0_temp_helios_test/requirements.txt diff --git a/0_temp_helios_test/helios-base_i2v.sh b/0_temp_helios_test/helios-base_i2v.sh deleted file mode 100644 index 44fcb203350c..000000000000 --- a/0_temp_helios_test/helios-base_i2v.sh +++ /dev/null @@ -1,14 +0,0 @@ -CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "BestWishYsh/Helios-Base" \ - --transformer_path "BestWishYsh/Helios-Base" \ - --sample_type "i2v" \ - --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" \ - --guidance_scale 5.0 \ - --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. 
As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." \ - --output_folder "./output_helios/helios-base" - - - # --enable_compile \ - # --use_cfg_zero_star \ - # --use_zero_init \ - # --zero_steps 1 \ \ No newline at end of file diff --git a/0_temp_helios_test/helios-base_t2v.sh b/0_temp_helios_test/helios-base_t2v.sh deleted file mode 100644 index 8626e439bfca..000000000000 --- a/0_temp_helios_test/helios-base_t2v.sh +++ /dev/null @@ -1,13 +0,0 @@ -CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "BestWishYsh/Helios-Base" \ - --transformer_path "BestWishYsh/Helios-Base" \ - --sample_type "t2v" \ - --prompt "A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and the vivid colors of its surroundings. A close-up shot with dynamic movement." 
\ - --guidance_scale 5.0 \ - --output_folder "./output_helios/helios-base" - - - # --enable_compile \ - # --use_cfg_zero_star \ - # --use_zero_init \ - # --zero_steps 1 \ \ No newline at end of file diff --git a/0_temp_helios_test/helios-base_v2v.sh b/0_temp_helios_test/helios-base_v2v.sh deleted file mode 100644 index 710a1b45eb74..000000000000 --- a/0_temp_helios_test/helios-base_v2v.sh +++ /dev/null @@ -1,14 +0,0 @@ -CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "BestWishYsh/Helios-Base" \ - --transformer_path "BestWishYsh/Helios-Base" \ - --sample_type "v2v" \ - --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ - --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ - --guidance_scale 5.0 \ - --output_folder "./output_helios/helios-base" - - - # --enable_compile \ - # --use_cfg_zero_star \ - # --use_zero_init \ - # --zero_steps 1 \ \ No newline at end of file diff --git a/0_temp_helios_test/helios-distilled_i2v.sh b/0_temp_helios_test/helios-distilled_i2v.sh deleted file mode 100644 index 00c147a57f5a..000000000000 --- a/0_temp_helios_test/helios-distilled_i2v.sh +++ /dev/null @@ -1,17 +0,0 @@ -CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ - --base_model_path "BestWishYsh/Helios-Distilled" \ - --transformer_path "BestWishYsh/Helios-Distilled" \ - --sample_type "i2v" \ - --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" \ - --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." 
\ - --num_frames 240 \ - --guidance_scale 1.0 \ - --is_enable_stage2 \ - --pyramid_num_inference_steps_list 2 2 2 \ - --is_enable_stage3 \ - --is_amplify_first_chunk \ - --output_folder "./output_helios/helios-distilled" - - - # --pyramid_num_inference_steps_list 1 1 1 \ - # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/helios-distilled_t2v.sh b/0_temp_helios_test/helios-distilled_t2v.sh deleted file mode 100644 index bbe95b9981eb..000000000000 --- a/0_temp_helios_test/helios-distilled_t2v.sh +++ /dev/null @@ -1,16 +0,0 @@ -CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ - --base_model_path "BestWishYsh/Helios-Distilled" \ - --transformer_path "BestWishYsh/Helios-Distilled" \ - --sample_type "t2v" \ - --prompt "A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and the vivid colors of its surroundings. A close-up shot with dynamic movement." 
\ - --num_frames 240 \ - --guidance_scale 1.0 \ - --is_enable_stage2 \ - --pyramid_num_inference_steps_list 2 2 2 \ - --is_enable_stage3 \ - --is_amplify_first_chunk \ - --output_folder "./output_helios/helios-distilled" - - - # --pyramid_num_inference_steps_list 1 1 1 \ - # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/helios-distilled_v2v.sh b/0_temp_helios_test/helios-distilled_v2v.sh deleted file mode 100644 index f6ecadcf0c82..000000000000 --- a/0_temp_helios_test/helios-distilled_v2v.sh +++ /dev/null @@ -1,17 +0,0 @@ -CUDA_VISIBLE_DEVICES=1 python infer_helios.py \ - --base_model_path "BestWishYsh/Helios-Distilled" \ - --transformer_path "BestWishYsh/Helios-Distilled" \ - --sample_type "v2v" \ - --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ - --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ - --num_frames 240 \ - --guidance_scale 1.0 \ - --is_enable_stage2 \ - --pyramid_num_inference_steps_list 2 2 2 \ - --is_enable_stage3 \ - --is_amplify_first_chunk \ - --output_folder "./output_helios/helios-distilled" - - - # --pyramid_num_inference_steps_list 1 1 1 \ - # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/helios-mid_i2v.sh b/0_temp_helios_test/helios-mid_i2v.sh deleted file mode 100644 index 46de67456e8d..000000000000 --- a/0_temp_helios_test/helios-mid_i2v.sh +++ /dev/null @@ -1,16 +0,0 @@ -CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "BestWishYsh/Helios-Mid" \ - --transformer_path "BestWishYsh/Helios-Mid" \ - --sample_type "i2v" \ - --image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg" \ - --prompt "A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and respect for nature’s might." 
\ - --guidance_scale 5.0 \ - --is_enable_stage2 \ - --pyramid_num_inference_steps_list 20 20 20 \ - --use_zero_init \ - --zero_steps 1 \ - --output_folder "./output_helios/helios-mid" - - - # --pyramid_num_inference_steps_list 17 17 17 \ - # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/helios-mid_t2v.sh b/0_temp_helios_test/helios-mid_t2v.sh deleted file mode 100644 index c56394776550..000000000000 --- a/0_temp_helios_test/helios-mid_t2v.sh +++ /dev/null @@ -1,15 +0,0 @@ -CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "BestWishYsh/Helios-Mid" \ - --transformer_path "BestWishYsh/Helios-Mid" \ - --sample_type "t2v" \ - --prompt "A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and the vivid colors of its surroundings. A close-up shot with dynamic movement." 
\ - --guidance_scale 5.0 \ - --is_enable_stage2 \ - --pyramid_num_inference_steps_list 20 20 20 \ - --use_zero_init \ - --zero_steps 1 \ - --output_folder "./output_helios/helios-mid" - - - # --pyramid_num_inference_steps_list 17 17 17 \ - # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/helios-mid_v2v.sh b/0_temp_helios_test/helios-mid_v2v.sh deleted file mode 100644 index c4eadade4370..000000000000 --- a/0_temp_helios_test/helios-mid_v2v.sh +++ /dev/null @@ -1,16 +0,0 @@ -CUDA_VISIBLE_DEVICES=0 python infer_helios.py \ - --base_model_path "BestWishYsh/Helios-Mid" \ - --transformer_path "BestWishYsh/Helios-Mid" \ - --sample_type "v2v" \ - --video_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4" \ - --prompt "A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery." 
\ - --guidance_scale 5.0 \ - --is_enable_stage2 \ - --pyramid_num_inference_steps_list 20 20 20 \ - --use_zero_init \ - --zero_steps 1 \ - --output_folder "./output_helios/helios-mid" - - - # --pyramid_num_inference_steps_list 17 17 17 \ - # --enable_compile \ \ No newline at end of file diff --git a/0_temp_helios_test/infer_helios.py b/0_temp_helios_test/infer_helios.py deleted file mode 100644 index 25432678ce14..000000000000 --- a/0_temp_helios_test/infer_helios.py +++ /dev/null @@ -1,312 +0,0 @@ -import os - - -os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes" -os.environ["HF_PARALLEL_LOADING_WORKERS"] = "8" - -import argparse -import time - -import pandas as pd -import torch -import torch.distributed as dist -from tqdm import tqdm - -from diffusers import HeliosTransformer3DModel -from diffusers import HeliosPipeline, HeliosPyramidPipeline - -from diffusers import ContextParallelConfig -from diffusers.models import AutoencoderKLWan -from diffusers.utils import export_to_video, load_image, load_video - - -def parse_args(): - parser = argparse.ArgumentParser(description="Generate video with model") - - # === Model paths === - parser.add_argument("--base_model_path", type=str, default="BestWishYsh/Helios-Base") - parser.add_argument( - "--transformer_path", - type=str, - default="BestWishYsh/Helios-Base", - ) - parser.add_argument( - "--lora_path", - type=str, - default=None, - ) - parser.add_argument("--output_folder", type=str, default="./output_helios") - parser.add_argument("--use_default_loader", action="store_true") - parser.add_argument("--enable_compile", action="store_true") - parser.add_argument("--low_vram_mode", action="store_true") - parser.add_argument("--enable_parallelism", action="store_true") - - # === Generation parameters === - # environment - parser.add_argument("--debug_mode", action="store_true") - parser.add_argument( - "--sample_type", - type=str, - default="t2v", - choices=["t2v", "i2v", "v2v"], - ) - parser.add_argument( - 
"--weight_dtype", - type=str, - default="bf16", - choices=["bf16", "fp16", "fp32"], - help="Data type for model weights.", - ) - parser.add_argument("--seed", type=int, default=42, help="Seed for random number generator.") - # base - parser.add_argument("--height", type=int, default=384) - parser.add_argument("--width", type=int, default=640) - parser.add_argument("--num_frames", type=int, default=99) - parser.add_argument("--num_inference_steps", type=int, default=50) - parser.add_argument("--guidance_scale", type=float, default=5.0) - # cfg zero - parser.add_argument("--use_zero_init", action="store_true") - parser.add_argument("--zero_steps", type=int, default=1) - # stage 1 - parser.add_argument("--num_latent_frames_per_chunk", type=int, default=9) - # stage 2 - parser.add_argument("--is_enable_stage2", action="store_true") - parser.add_argument("--stage2_timestep_shift", type=float, default=1.0) - parser.add_argument("--stage2_scheduler_gamma", type=float, default=1 / 3) - parser.add_argument("--stage2_stage_range", type=int, nargs="+", default=[0, 1 / 3, 2 / 3, 1]) - parser.add_argument("--pyramid_num_inference_steps_list", type=int, nargs="+", default=[20, 20, 20]) - # stage 3 - parser.add_argument("--is_enable_stage3", action="store_true") - parser.add_argument("--is_skip_first_chunk", action="store_true") - parser.add_argument("--is_amplify_first_chunk", action="store_true") - - # === Prompts === - parser.add_argument("--use_interpolate_prompt", action="store_true") - parser.add_argument("--interpolation_steps", type=int, default=3) - parser.add_argument( - "--image_path", - type=str, - default=None, - ) - parser.add_argument( - "--video_path", - type=str, - default=None, - ) - parser.add_argument( - "--prompt", - type=str, - default="A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. 
The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond.", - ) - parser.add_argument( - "--negative_prompt", - type=str, - default="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", - ) - parser.add_argument( - "--prompt_txt_path", - type=str, - default=None, - ) - parser.add_argument( - "--interactive_prompt_csv_path", - type=str, - default=None, - ) - parser.add_argument( - "--base_image_prompt_path", - type=str, - default=None, - ) - parser.add_argument( - "--image_prompt_csv_path", - type=str, - default=None, - ) - - return parser.parse_args() - - -def main(): - args = parse_args() - - assert not (args.low_vram_mode and args.enable_compile), ( - "low_vram_mode and enable_compile cannot be used together." 
- ) - - if args.weight_dtype == "fp32": - args.weight_dtype = torch.float32 - elif args.weight_dtype == "fp16": - args.weight_dtype = torch.float16 - else: - args.weight_dtype = torch.bfloat16 - - os.makedirs(args.output_folder, exist_ok=True) - - if dist.is_available() and "RANK" in os.environ: - dist.init_process_group(backend="nccl") - rank = dist.get_rank() - device = torch.device("cuda", rank % torch.cuda.device_count()) - world_size = dist.get_world_size() - torch.cuda.set_device(device) - assert world_size == 1 or not args.low_vram_mode, "low_vram_mode is only for single GPU." - else: - rank = 0 - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - world_size = 1 - - prompt = None - image_path = None - video_path = None - interpolate_time_list = None - if args.sample_type == "t2v" and args.prompt is None: - prompt = "An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film." - elif args.sample_type == "i2v" and (args.image_path is None and args.prompt is None): - image_path = ( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" - ) - prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." 
- elif args.sample_type == "v2v" and (args.video_path is None and args.prompt is None): - video_path = ( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" - ) - prompt = "A robot standing on a mountain top. The sun is setting in the background." - else: - image_path = args.image_path - video_path = args.video_path - if args.interactive_prompt_csv_path is not None and args.use_interpolate_prompt: - with open(args.prompt, "r") as f: - lines = [line.strip() for line in f.readlines() if line.strip()] - interpolate_time_list = [] - prompt = [] - for line in lines: - parts = line.split(",", 1) - if len(parts) == 2: - time_value = int(parts[0].strip()) - prompt_text = parts[1].strip().strip('"') - - interpolate_time_list.append(time_value) - prompt.append(prompt_text) - else: - prompt = args.prompt - - transformer = HeliosTransformer3DModel.from_pretrained( - args.transformer_path, - subfolder="transformer", - torch_dtype=args.weight_dtype, - use_default_loader=args.use_default_loader, - ) - try: - transformer.set_attention_backend("_flash_3_hub") - except: - transformer.set_attention_backend("flash_hub") - - vae = AutoencoderKLWan.from_pretrained( - args.base_model_path, - subfolder="vae", - torch_dtype=torch.float32, - ) - if not args.is_enable_stage2: - pipe = HeliosPipeline.from_pretrained( - args.base_model_path, - transformer=transformer, - vae=vae, - torch_dtype=args.weight_dtype, - ) - else: - pipe = HeliosPyramidPipeline.from_pretrained( - args.base_model_path, - transformer=transformer, - vae=vae, - torch_dtype=args.weight_dtype, - ) - - if args.lora_path is not None: - pipe.load_lora_weights(args.lora_path, adapter_name="default") - pipe.set_adapters(["default"], adapter_weights=[1.0]) - - if args.enable_compile: - torch.backends.cudnn.benchmark = True - pipe.text_encoder.compile(mode="max-autotune-no-cudagraphs", dynamic=False) - pipe.vae.compile(mode="max-autotune-no-cudagraphs", dynamic=False) - 
pipe.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=False) - - if args.low_vram_mode: - pipe.enable_group_offload( - onload_device=torch.device("cuda"), - offload_device=torch.device("cpu"), - # offload_type="leaf_level", - offload_type="block_level", - num_blocks_per_group=1, - use_stream=True, - record_stream=True, - ) - else: - pipe = pipe.to(device) - - if world_size > 1 and args.enable_parallelism: - # transformer.set_attention_backend("flash") - pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=world_size)) - - with torch.no_grad(): - if not args.is_enable_stage2: - output = pipe( - prompt=prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, # 73 109 145 181 215 - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - generator=torch.Generator(device="cuda").manual_seed(args.seed), - # stage 1 - history_sizes=[16, 2, 1], - num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, - keep_first_frame=True, - is_skip_first_chunk=args.is_skip_first_chunk, - # i2v - image=load_image(image_path).resize((args.width, args.height)) if image_path is not None else None, - # t2v - video=load_video(video_path) if video_path is not None else None, - ).frames[0] - else: - output = pipe( - prompt=prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, # 73 109 145 181 215 - guidance_scale=args.guidance_scale, - generator=torch.Generator(device="cuda").manual_seed(args.seed), - # stage 1 - history_sizes=[16, 2, 1], - num_latent_frames_per_chunk=args.num_latent_frames_per_chunk, - keep_first_frame=True, - is_skip_first_chunk=args.is_skip_first_chunk, - # stage 2 - pyramid_num_inference_steps_list=args.pyramid_num_inference_steps_list, - # stage 3 - is_amplify_first_chunk=args.is_amplify_first_chunk, - # cfg zero - use_zero_init=args.use_zero_init, - 
zero_steps=args.zero_steps, - # i2v - image=load_image(image_path).resize((args.width, args.height)) if image_path is not None else None, - # t2v - video=load_video(video_path) if video_path is not None else None, - ).frames[0] - - if not args.enable_parallelism or rank == 0: - file_count = len( - [f for f in os.listdir(args.output_folder) if os.path.isfile(os.path.join(args.output_folder, f))] - ) - output_path = os.path.join( - args.output_folder, f"{file_count:04d}_{args.sample_type}_{int(time.time())}.mp4" - ) - export_to_video(output, output_path, fps=24) - - print(f"Max memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB") - - -if __name__ == "__main__": - main() diff --git a/0_temp_helios_test/requirements.txt b/0_temp_helios_test/requirements.txt deleted file mode 100644 index b381119d3944..000000000000 --- a/0_temp_helios_test/requirements.txt +++ /dev/null @@ -1,36 +0,0 @@ -torch==2.7.1 -torchvision==0.22.1 -torchaudio==2.7.1 -triton==3.3.1 -# diffusers==0.36.0 -# transformers==4.57.6 -# sentence-transformers==5.2.3 -# git+https://github.com/SHYuanBest/diffusers.git@test -git+https://github.com/huggingface/transformers.git -git+https://github.com/huggingface/sentence-transformers.git -accelerate==1.12.0 -deepspeed==0.18.4 -peft==0.18.1 -huggingface-hub==1.4.1 -zstandard==0.25.0 -wandb==0.23.0 -video-reader-rs==0.4.1 -numpy<2.0.0 -opencv-python -gradio -spaces -moviepy -imageio-ffmpeg -ftfy -Jinja2 -einops -nvitop -packaging -ninja -omegaconf -mpi4py -hf-doc-builder -torchdata -kernels -loguru -tf_keras \ No newline at end of file diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index d00a24a8d105..c575621ac888 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -88,9 +88,9 @@ output = pipeline( negative_prompt=negative_prompt, num_frames=99, num_inference_steps=50, - use_dynamic_shifting=True, + guidance_scale=5.0, ).frames[0] -export_to_video(output, 
"output.mp4", fps=24) +export_to_video(output, "helios_base_t2v_output.mp4", fps=24) ``` @@ -141,9 +141,9 @@ output = pipeline( negative_prompt=negative_prompt, num_frames=99, num_inference_steps=50, - use_dynamic_shifting=True, + guidance_scale=5.0, ).frames[0] -export_to_video(output, "output.mp4", fps=24) +export_to_video(output, "helios_base_t2v_output.mp4", fps=24) ``` @@ -192,9 +192,9 @@ output = pipeline( negative_prompt=negative_prompt, num_frames=99, num_inference_steps=50, - use_dynamic_shifting=True, + guidance_scale=5.0, ).frames[0] -export_to_video(output, "output_t2v.mp4", fps=24) +export_to_video(output, "helios_base_t2v_output.mp4", fps=24) # For Image-to-Video prompt = """ @@ -213,9 +213,9 @@ output = pipeline( image=load_image(image_path).resize((640, 384)), num_frames=99, num_inference_steps=50, - use_dynamic_shifting=True, + guidance_scale=5.0, ).frames[0] -export_to_video(output, "output_i2v.mp4", fps=24) +export_to_video(output, "helios_base_i2v_output.mp4", fps=24) # For Video-to-Video prompt = """ @@ -233,9 +233,9 @@ output = pipeline( video=load_video(video_path), num_frames=99, num_inference_steps=50, - use_dynamic_shifting=True, + guidance_scale=5.0, ).frames[0] -export_to_video(output, "output_v2v.mp4", fps=24) +export_to_video(output, "helios_base_v2v_output.mp4", fps=24) ``` @@ -284,10 +284,11 @@ output = pipeline( negative_prompt=negative_prompt, num_frames=99, pyramid_num_inference_steps_list=[20, 20, 20], + guidance_scale=5.0, use_zero_init=True, zero_steps=1, ).frames[0] -export_to_video(output, "output_t2v.mp4", fps=24) +export_to_video(output, "helios_pyramid_t2v_output.mp4", fps=24) # For Image-to-Video prompt = """ @@ -306,10 +307,11 @@ output = pipeline( image=load_image(image_path).resize((640, 384)), num_frames=99, pyramid_num_inference_steps_list=[20, 20, 20], + guidance_scale=5.0, use_zero_init=True, zero_steps=1, ).frames[0] -export_to_video(output, "output_i2v.mp4", fps=24) +export_to_video(output, 
"helios_pyramid_i2v_output.mp4", fps=24) # For Video-to-Video prompt = """ @@ -327,10 +329,11 @@ output = pipeline( video=load_video(video_path), num_frames=99, pyramid_num_inference_steps_list=[20, 20, 20], + guidance_scale=5.0, use_zero_init=True, zero_steps=1, ).frames[0] -export_to_video(output, "output_v2v.mp4", fps=24) +export_to_video(output, "helios_pyramid_v2v_output.mp4", fps=24) ``` @@ -382,7 +385,7 @@ output = pipeline( guidance_scale=1.0, is_amplify_first_chunk=True, ).frames[0] -export_to_video(output, "output_t2v.mp4", fps=24) +export_to_video(output, "helios_distilled_t2v_output.mp4", fps=24) # For Image-to-Video prompt = """ @@ -404,7 +407,7 @@ output = pipeline( guidance_scale=1.0, is_amplify_first_chunk=True, ).frames[0] -export_to_video(output, "output_i2v.mp4", fps=24) +export_to_video(output, "helios_distilled_i2v_output.mp4", fps=24) # For Video-to-Video prompt = """ @@ -425,7 +428,7 @@ output = pipeline( guidance_scale=1.0, is_amplify_first_chunk=True, ).frames[0] -export_to_video(output, "output_v2v.mp4", fps=24) +export_to_video(output, "helios_distilled_v2v_output.mp4", fps=24) ``` From dcbb1823605b85f237f725fa98cf896c86b8a2d0 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Wed, 4 Mar 2026 09:33:41 +0000 Subject: [PATCH 105/107] add torch.Generator --- docs/source/en/api/pipelines/helios.md | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index c575621ac888..86cdf3eb621e 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -89,6 +89,7 @@ output = pipeline( num_frames=99, num_inference_steps=50, guidance_scale=5.0, + generator=torch.Generator("cuda").manual_seed(42), ).frames[0] export_to_video(output, "helios_base_t2v_output.mp4", fps=24) ``` @@ -142,6 +143,7 @@ output = pipeline( num_frames=99, num_inference_steps=50, guidance_scale=5.0, + 
generator=torch.Generator("cuda").manual_seed(42), ).frames[0] export_to_video(output, "helios_base_t2v_output.mp4", fps=24) ``` @@ -193,6 +195,7 @@ output = pipeline( num_frames=99, num_inference_steps=50, guidance_scale=5.0, + generator=torch.Generator("cuda").manual_seed(42), ).frames[0] export_to_video(output, "helios_base_t2v_output.mp4", fps=24) @@ -214,6 +217,7 @@ output = pipeline( num_frames=99, num_inference_steps=50, guidance_scale=5.0, + generator=torch.Generator("cuda").manual_seed(42), ).frames[0] export_to_video(output, "helios_base_i2v_output.mp4", fps=24) @@ -234,6 +238,7 @@ output = pipeline( num_frames=99, num_inference_steps=50, guidance_scale=5.0, + generator=torch.Generator("cuda").manual_seed(42), ).frames[0] export_to_video(output, "helios_base_v2v_output.mp4", fps=24) ``` @@ -287,6 +292,7 @@ output = pipeline( guidance_scale=5.0, use_zero_init=True, zero_steps=1, + generator=torch.Generator("cuda").manual_seed(42), ).frames[0] export_to_video(output, "helios_pyramid_t2v_output.mp4", fps=24) @@ -310,6 +316,7 @@ output = pipeline( guidance_scale=5.0, use_zero_init=True, zero_steps=1, + generator=torch.Generator("cuda").manual_seed(42), ).frames[0] export_to_video(output, "helios_pyramid_i2v_output.mp4", fps=24) @@ -332,6 +339,7 @@ output = pipeline( guidance_scale=5.0, use_zero_init=True, zero_steps=1, + generator=torch.Generator("cuda").manual_seed(42), ).frames[0] export_to_video(output, "helios_pyramid_v2v_output.mp4", fps=24) ``` @@ -380,10 +388,11 @@ the vivid colors of its surroundings. A close-up shot with dynamic movement. 
output = pipeline( prompt=prompt, negative_prompt=negative_prompt, - num_frames=99, + num_frames=240, pyramid_num_inference_steps_list=[2, 2, 2], guidance_scale=1.0, is_amplify_first_chunk=True, + generator=torch.Generator("cuda").manual_seed(42), ).frames[0] export_to_video(output, "helios_distilled_t2v_output.mp4", fps=24) @@ -402,10 +411,11 @@ output = pipeline( prompt=prompt, negative_prompt=negative_prompt, image=load_image(image_path).resize((640, 384)), - num_frames=99, + num_frames=240, pyramid_num_inference_steps_list=[2, 2, 2], guidance_scale=1.0, is_amplify_first_chunk=True, + generator=torch.Generator("cuda").manual_seed(42), ).frames[0] export_to_video(output, "helios_distilled_i2v_output.mp4", fps=24) @@ -423,10 +433,11 @@ output = pipeline( prompt=prompt, negative_prompt=negative_prompt, video=load_video(video_path), - num_frames=99, + num_frames=240, pyramid_num_inference_steps_list=[2, 2, 2], guidance_scale=1.0, is_amplify_first_chunk=True, + generator=torch.Generator("cuda").manual_seed(42), ).frames[0] export_to_video(output, "helios_distilled_v2v_output.mp4", fps=24) ``` From b9efc1558678beb7cb6bda5418b9fe7f53c71199 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 4 Mar 2026 12:00:47 +0100 Subject: [PATCH 106/107] Fix HeliosPipelineOutput doc path --- docs/source/en/api/pipelines/helios.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index 86cdf3eb621e..22db16491b99 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -462,4 +462,4 @@ export_to_video(output, "helios_distilled_v2v_output.mp4", fps=24) ## HeliosPipelineOutput -[[autodoc]] pipelines.Helios.pipeline_output.HeliosPipelineOutput +[[autodoc]] pipelines.helios.pipeline_output.HeliosPipelineOutput From ab8127d6e28b4fe2783edbac2f9a8454905125ba Mon Sep 17 00:00:00 2001 From: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> Date: 
Wed, 4 Mar 2026 23:14:43 +0800 Subject: [PATCH 107/107] Fix Helios related (optimize docs & remove redudant) (#13210) * fix docs * remove redudant * remove redudant * fix group offload * Removed fixes for group offload --- docs/source/en/api/pipelines/helios.md | 6 +++--- .../models/transformers/transformer_helios.py | 19 +++++++++++-------- .../pipelines/helios/pipeline_helios.py | 15 +++++++++------ 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/docs/source/en/api/pipelines/helios.md b/docs/source/en/api/pipelines/helios.md index 22db16491b99..81559b24c071 100644 --- a/docs/source/en/api/pipelines/helios.md +++ b/docs/source/en/api/pipelines/helios.md @@ -22,7 +22,7 @@ # Helios -[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Xinwei Huang, Xiao Yang, Li Yuan. +[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Zongjian Li, Xinwei Huang, Xiao Yang, Li Yuan. * We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality. We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. 
To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios-Page). @@ -360,10 +360,10 @@ import torch from diffusers import AutoModel, HeliosPyramidPipeline from diffusers.utils import export_to_video, load_video, load_image -vae = AutoModel.from_pretrained("BestWishYsh//Helios-Distilled", subfolder="vae", torch_dtype=torch.float32) +vae = AutoModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="vae", torch_dtype=torch.float32) pipeline = HeliosPyramidPipeline.from_pretrained( - "BestWishYsh//Helios-Distilled", + "BestWishYsh/Helios-Distilled", vae=vae, torch_dtype=torch.bfloat16 ) diff --git a/src/diffusers/models/transformers/transformer_helios.py b/src/diffusers/models/transformers/transformer_helios.py index 43b0327273b9..9f3ef047d98d 100644 --- a/src/diffusers/models/transformers/transformer_helios.py +++ b/src/diffusers/models/transformers/transformer_helios.py @@ -536,7 +536,14 @@ class HeliosTransformer3DModel( """ _supports_gradient_checkpointing = True - _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] + _skip_layerwise_casting_patterns = [ + "patch_embedding", + "patch_short", + "patch_mid", + "patch_long", + "condition_embedder", + "norm", + ] _no_split_modules = 
["HeliosTransformerBlock", "HeliosOutputNorm"] _keep_in_fp32_modules = [ "time_embedder", @@ -594,18 +601,17 @@ def __init__( # 2. Initial Multi Term Memory Patch self.zero_history_timestep = zero_history_timestep - self.inner_dim = inner_dim if has_multi_term_memory_patch: - self.patch_short = nn.Conv3d(in_channels, self.inner_dim, kernel_size=patch_size, stride=patch_size) + self.patch_short = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) self.patch_mid = nn.Conv3d( in_channels, - self.inner_dim, + inner_dim, kernel_size=tuple(2 * p for p in patch_size), stride=tuple(2 * p for p in patch_size), ) self.patch_long = nn.Conv3d( in_channels, - self.inner_dim, + inner_dim, kernel_size=tuple(4 * p for p in patch_size), stride=tuple(4 * p for p in patch_size), ) @@ -683,7 +689,6 @@ def forward( # 3. Process short history latents if latents_history_short is not None and indices_latents_history_short is not None: - latents_history_short = latents_history_short.to(hidden_states) latents_history_short = self.patch_short(latents_history_short) _, _, _, H1, W1 = latents_history_short.shape latents_history_short = latents_history_short.flatten(2).transpose(1, 2) @@ -701,7 +706,6 @@ def forward( # 4. Process mid history latents if latents_history_mid is not None and indices_latents_history_mid is not None: - latents_history_mid = latents_history_mid.to(hidden_states) latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4)) latents_history_mid = self.patch_mid(latents_history_mid) latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2) @@ -721,7 +725,6 @@ def forward( # 5. 
Process long history latents if latents_history_long is not None and indices_latents_history_long is not None: - latents_history_long = latents_history_long.to(hidden_states) latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8)) latents_history_long = self.patch_long(latents_history_long) latents_history_long = latents_history_long.flatten(2).transpose(1, 2) diff --git a/src/diffusers/pipelines/helios/pipeline_helios.py b/src/diffusers/pipelines/helios/pipeline_helios.py index 0902ec425e69..87a8600badab 100644 --- a/src/diffusers/pipelines/helios/pipeline_helios.py +++ b/src/diffusers/pipelines/helios/pipeline_helios.py @@ -815,6 +815,9 @@ def __call__( timestep = t.expand(latents.shape[0]) latent_model_input = latents.to(transformer_dtype) + latents_history_short = latents_history_short.to(transformer_dtype) + latents_history_mid = latents_history_mid.to(transformer_dtype) + latents_history_long = latents_history_long.to(transformer_dtype) with self.transformer.cache_context("cond"): noise_pred = self.transformer( hidden_states=latent_model_input, @@ -824,9 +827,9 @@ def __call__( indices_latents_history_short=indices_latents_history_short, indices_latents_history_mid=indices_latents_history_mid, indices_latents_history_long=indices_latents_history_long, - latents_history_short=latents_history_short.to(transformer_dtype), - latents_history_mid=latents_history_mid.to(transformer_dtype), - latents_history_long=latents_history_long.to(transformer_dtype), + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -841,9 +844,9 @@ def __call__( indices_latents_history_short=indices_latents_history_short, indices_latents_history_mid=indices_latents_history_mid, indices_latents_history_long=indices_latents_history_long, - latents_history_short=latents_history_short.to(transformer_dtype), - 
latents_history_mid=latents_history_mid.to(transformer_dtype), - latents_history_long=latents_history_long.to(transformer_dtype), + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, attention_kwargs=attention_kwargs, return_dict=False, )[0]