From 9720735f2f2d0a4801eb80e7a737cc3e122b1947 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Wed, 4 Mar 2026 15:45:56 +0000
Subject: [PATCH 1/2] Distribute VAE video encoding across processes in CogVideoX LoRA training

Signed-off-by: jiqing-feng
---
Review notes (everything between the "---" line and the first "diff --git" is ignored by `git am`):
 * Each process encodes only the videos whose index satisfies i % num_processes == process_index,
   then every latent is broadcast from its owning rank so all ranks end up with the full list.
 * Receive buffers are now sized from per-video (shape, dtype) metadata exchanged once with
   dist.all_gather_object. The earlier form sized them with torch.empty_like() on the first locally
   encoded latent, which raised StopIteration on a rank with an empty shard (fewer videos than
   processes) and assumed every latent has the same shape.
 * The hunk header and diffstat were updated by hand for that change; the blob ids on the "index"
   lines are those of the originally generated series.

 examples/cogvideo/train_cogvideox_lora.py | 51 +++++++++++++++++++----
 1 file changed, 44 insertions(+), 7 deletions(-)

diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 4e22d3f8727d..4b212271597e 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -1232,22 +1232,59 @@ def load_model_hook(models, input_dir):
         id_token=args.id_token,
     )
 
-    def encode_video(video, bar):
-        bar.update(1)
+    def encode_video(video):
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         latent_dist = vae.encode(video).latent_dist
         return latent_dist
 
+    # Distribute video encoding across processes: each process only encodes its own shard
+    num_videos = len(train_dataset.instance_videos)
+    num_procs = accelerator.num_processes
+    local_rank = accelerator.process_index
+    local_count = len(range(local_rank, num_videos, num_procs))
+
     progress_encode_bar = tqdm(
-        range(0, len(train_dataset.instance_videos)),
-        desc="Loading Encode videos",
+        range(local_count),
+        desc="Encoding videos",
+        disable=not accelerator.is_local_main_process,
     )
-    train_dataset.instance_videos = [
-        encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
-    ]
+
+    encoded_videos = [None] * num_videos
+    for i, video in enumerate(train_dataset.instance_videos):
+        if i % num_procs == local_rank:
+            encoded_videos[i] = encode_video(video)
+            progress_encode_bar.update(1)
     progress_encode_bar.close()
 
+    # Broadcast encoded latent distributions so every process has the full set
+    if num_procs > 1:
+        import torch.distributed as dist
+        from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+
+        # Exchange latent shapes/dtypes up front: clips may differ in frame count, and a rank with an
+        # empty shard (fewer videos than processes) has no local latent to size its receive buffers from.
+        local_meta = {
+            i: (tuple(v.parameters.shape), v.parameters.dtype) for i, v in enumerate(encoded_videos) if v is not None
+        }
+        gathered_meta = [None] * num_procs
+        dist.all_gather_object(gathered_meta, local_meta)
+        latent_meta = {}
+        for rank_meta in gathered_meta:
+            latent_meta.update(rank_meta)
+
+        for i in range(num_videos):
+            src = i % num_procs
+            if encoded_videos[i] is not None:
+                params = encoded_videos[i].parameters.contiguous()
+            else:
+                shape, dtype = latent_meta[i]
+                params = torch.empty(shape, dtype=dtype, device=accelerator.device)
+            dist.broadcast(params, src=src)
+            encoded_videos[i] = DiagonalGaussianDistribution(params)
+
+    train_dataset.instance_videos = encoded_videos
+
     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
         prompts = [example["instance_prompt"] for example in examples]

From 075eabcecec1687e77a2e1bc017811b7a4d2bbb9 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Wed, 4 Mar 2026 08:37:46 +0000
Subject: [PATCH 2/2] Apply style fixes

---
Review note (ignored by `git am`): the last context line of the hunk was updated to match the
revised PATCH 1/2; the change itself (one blank line between the two imports) is untouched.

 examples/cogvideo/train_cogvideox_lora.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 4b212271597e..e08143f98a5c 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -1260,6 +1260,7 @@ def encode_video(video):
     # Broadcast encoded latent distributions so every process has the full set
     if num_procs > 1:
         import torch.distributed as dist
+
         from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 
         # Exchange latent shapes/dtypes up front: clips may differ in frame count, and a rank with an