From aa602ac4831d44a3bab2d7d90f62096e5146ed59 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 12 Dec 2025 07:52:33 +0100 Subject: [PATCH 01/86] Initial LTX 2.0 transformer implementation --- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_ltx2.py | 1206 +++++++++++++++++ 4 files changed, 1211 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_ltx2.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e69d334fdb8f..97ba02e2d03d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -236,6 +236,7 @@ "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", "LTXVideoTransformer3DModel", + "LTX2VideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", "MochiTransformer3DModel", @@ -969,6 +970,7 @@ Kandinsky5Transformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, + LTX2VideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 29d8b0b5a55d..b387bd817c2d 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -102,6 +102,7 @@ _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] + _import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] @@ -209,6 +210,7 @@ Kandinsky5Transformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, + LTX2VideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index a42f6b2716e1..cc8aff81425f 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -34,6 +34,7 @@ from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel + from .transformer_ltx2 import LTX2VideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py new file mode 100644 index 000000000000..57d71a3eb6a2 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -0,0 +1,1206 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, PixArtAlphaCombinedTimestepSizeEmbeddings +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def apply_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + +@dataclass +class AudioVisualModelOutput(BaseOutput): + r""" + Holds the output of an audiovisual model which produces both visual (e.g. video) and audio outputs. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on the `encoder_hidden_states` input, representing the visual output + of the model. This is typically a video (spatiotemporal) output. + audio_sample (`torch.Tensor` of shape `(batch_size, TODO)`): + The audio output of the audiovisual model. + """ + + sample: "torch.Tensor" # noqa: F821 + audio_sample: "torch.Tensor" # noqa: F821 + + +class LTX2AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted by the LTX-2.0 + model. In particular, the number of modulation parameters to be calculated is now configurable. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_mod_params (`int`, *optional*, defaults to `6`): + The number of modulation parameters which will be calculated in the first return argument. The default of 6 + is standard, but sometimes we may want to have a different (usually smaller) number of modulation + parameters. + use_additional_conditions (`bool`, *optional*, defaults to `False`): + Whether to use additional conditions for normalization or not. 
+ """ + + def __init__(self, embedding_dim: int, num_mod_params: int = 6, use_additional_conditions: bool = False): + super().__init__() + self.num_mod_params = num_mod_params + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, self.num_mod_params * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class LTX2AudioVideoAttnProcessor: + r""" + Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model. + Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can + support audio-to-video (a2v) and video-to-audio (v2a) cross attention. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if is_torch_version("<", "2.0"): + raise ValueError( + "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." + ) + + def __call__( + self, + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if query_rotary_emb is not None: + query = apply_rotary_emb(query, query_rotary_emb) + key = apply_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LTX2Attention(torch.nn.Module, AttentionModuleMixin): + r""" + Attention class for all LTX-2.0 attention layers. 
Compared to LTX-1.0, this supports specifying the query and key + RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention. + """ + + _default_processor_cls = LTX2AudioVideoAttnProcessor + _available_processors = [LTX2AudioVideoAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + kv_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = True, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + qk_norm: str = "rms_norm_across_heads", + norm_eps: float = 1e-6, + norm_elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + if qk_norm != "rms_norm_across_heads": + raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.") + + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = query_dim + self.heads = heads + + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + hidden_states = self.processor( + self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs + ) + return hidden_states + + +@maybe_allow_in_graph +class LTX2VideoTransformerBlock(nn.Module): + r""" + Transformer block used in [LTX-2.0](https://huggingface.co/Lightricks/LTX-Video). + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_dim: int, + audio_num_attention_heads: int, + audio_attention_head_dim, + audio_cross_attention_dim: int, + qk_norm: str = "rms_norm_across_heads", + activation_fn: str = "gelu-approximate", + attention_bias: bool = True, + attention_out_bias: bool = True, + eps: float = 1e-6, + elementwise_affine: bool = False, + ): + super().__init__() + + # 1. Self-Attention (video and audio) + self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + self.audio_norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn1 = LTX2Attention( + query_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + # 2. Prompt Cross-Attention + self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn2 = LTX2Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + self.audio_norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn2 = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=audio_cross_attention_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + # Audio-to-Video (a2v) Attention --> Q: Video; K,V: Audio + self.audio_to_video_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_to_video_attn = LTX2Attention( + query_dim=dim, + cross_attention_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video + self.video_to_audio_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.video_to_audio_attn = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + ) + + # 4. Feedforward layers + self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + self.audio_norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn) + + # 5. 
Per-Layer Modulation Parameters + # Self-Attention / Feedforward AdaLayerNorm-Zero mod params + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # Per-layer a2v, v2a Cross-Attention mod params + self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) + self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim)) + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + temb_audio: torch.Tensor, + temb_ca_scale_shift: torch.Tensor, + temb_ca_audio_scale_shift: torch.Tensor, + temb_ca_gate: torch.Tensor, + temb_ca_audio_gate: torch.Tensor, + video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ca_video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ca_audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + a2v_cross_attention_mask: Optional[torch.Tensor] = None, + v2a_cross_attention_mask: Optional[torch.Tensor] = None, + use_video_self_attn: bool = True, + use_audio_self_attn: bool = True, + use_a2v_cross_attn: bool = True, + use_v2a_cross_attn: bool = True, + ) -> torch.Tensor: + batch_size = hidden_states.size(0) + + # 1. Video and Audio Self-Attention + if use_video_self_attn: + norm_hidden_states = self.norm1(hidden_states) + + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( + batch_size, temb.size(1), num_ada_params, -1 + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=video_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa + + if use_audio_self_attn: + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) + + num_audio_ada_params = self.audio_scale_shift_table.shape[0] + audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( + batch_size, temb_audio.size(1), num_audio_ada_params, -1 + ) + audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = audio_ada_values.unbind(dim=2) + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa + + attn_audio_hidden_states = self.audio_attn1( + hidden_states=norm_audio_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=audio_rotary_emb, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa + + # 2. 
Video and Audio Cross-Attention with the text embeddings + norm_hidden_states = self.norm2(hidden_states) + attn_hidden_states = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_rotary_emb=None, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + + norm_audio_hidden_states = self.audio_norm2(audio_hidden_states) + attn_audio_hidden_states = self.audio_attn2( + norm_audio_hidden_states, + encoder_hidden_states=audio_encoder_hidden_states, + query_rotary_emb=None, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + + # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + if use_a2v_cross_attn or use_v2a_cross_attn: + norm_hidden_states = self.norm3(hidden_states) + norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) + + # Combine global and per-layer cross attention modulation parameters + # Video + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] + + video_ca_scale_shift_table = ( + video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) + + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + video_ca_gate = ( + video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) + + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) + ).unbind(dim=2) + + video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table + a2v_gate = video_ca_gate[0].squeeze(2) + + # Audio + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + + audio_ca_scale_shift_table = ( + audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype) + + temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + audio_ca_gate = ( + audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) + + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) + ).unbind(dim=2) + + audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table + v2a_gate = audio_ca_gate[0].squeeze(2) + + if use_a2v_cross_attn: + # Audio-to-Video Cross Attention: Q: Video; K,V: Audio + mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_a2v_ca_scale.squeeze(2)) + audio_a2v_ca_shift.squeeze(2) + + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + query_rotary_emb=ca_video_rotary_emb, + key_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) + + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + if use_v2a_cross_attn: + # Video-to-Audio Cross Attention: Q: Audio; K,V: Video + mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_v2a_ca_scale.squeeze(2)) + audio_v2a_ca_shift.squeeze(2) + + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + encoder_hidden_states=mod_norm_hidden_states, + 
query_rotary_emb=ca_audio_rotary_emb, + key_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) + + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + + # 4. Feedforward + norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp + + norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) * (1 + audio_scale_mlp) + audio_shift_mlp + audio_ff_output = self.audio_ff(norm_audio_hidden_states) + audio_hidden_states = audio_hidden_states + audio_ff_output * audio_gate_mlp + + return hidden_states, audio_hidden_states + + +class LTX2AudioVideoRotaryPosEmbed(nn.Module): + """ + Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model. + + Args: + causal_offset (`int`, *optional*, defaults to `1`): + Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where + the VAE treats the very first frame differently), but could also be 0 (for non-causal modeling). + """ + def __init__( + self, + dim: int, + patch_size: int = 1, + patch_size_t: int = 1, + base_num_frames: int = 20, + base_height: int = 2048, + base_width: int = 2048, + sampling_rate: int = 16000, + hop_length: int = 160, + scale_factors: Tuple[int, ...] = (8, 32 ,32), + theta: float = 10000.0, + causal_offset: int = 1, + modality: str = "video", + ) -> None: + super().__init__() + + self.dim = dim + self.patch_size = patch_size + self.patch_size_t = patch_size_t + + self.base_num_frames = base_num_frames + + # Video-specific + self.base_height = base_height + self.base_width = base_width + + # Audio-specific + self.sampling_rate = sampling_rate + self.hop_length = hop_length + + self.scale_factors = scale_factors + self.theta = theta + self.causal_offset = causal_offset + + self.modality = modality + if self.modality not in ["video", "audio"]: + raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.") + + def prepare_video_coords( + self, + batch_size: int, + num_frames: int, + height: int, + width: int, + device: torch.device, + ) -> torch.Tensor: + """ + Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original + pixel space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, + num_patches, 2) where + - axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames) + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + batch_size (`int`): + Batch size of the video latents. + num_frames (`int`): + Number of latent frames in the video latents. + height (`int`): + Latent height of the video latents. + width (`int`): + Latent width of the video latents. + device (`torch.device`): + Device on which to create the video grid. + + Returns: + `torch.Tensor`: + Per-dimension patch boundaries tensor of shape [batch_size, 3, num_patches, 2]. + """ + + # 1. 
Generate grid coordinates for each spatiotemporal dimension (frames, height, width) + # Always compute rope in fp32 + grid_f = torch.arange(start=0, end=num_frames, step=self.patch_size_t, dtype=torch.float32, device=device) + grid_h = torch.arange(start=0, end=height, step=self.patch_size, dtype=torch.float32, device=device) + grid_w = torch.arange(start=0, end=width, step=self.patch_size, dtype=torch.float32, device=device) + # indexing='ij' ensures that the dimensions are kept in order as (frames, height, width) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches + + # 2. Get the patch boundaries with respect to the latent video grid + patch_size = (self.patch_size_t, self.patch_size, self.patch_size) + patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device) + patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) + + # Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension + latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2] + # Reshape to (batch_size, 3, num_patches, 2) + latent_coords = latent_coords.flatten(1, 3) + latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) + + # 3. Calculate the pixel space patch boundaries from the latent boundaries. + scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device) + # Broadcast the VAE scale factors such that they are compatible with latent_coords's shape + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 # This is the (frame, height, width) dim + # Apply per-axis scaling to convert latent coordinates to pixel space coordinates + pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape) + + # As the VAE temporal stride for the first frame is 1 instead of self.vae_scale_factors[0], we need to shift + # and clamp to keep the first-frame timestamps causal and non-negative. + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0) + + return pixel_coords + + def prepare_audio_coords( + self, + batch_size: int, + num_frames: int, + device: torch.device, + shift: int = 0, + ) -> torch.Tensor: + """ + Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent + frame. This will ultimately have shape (batch_size, 3, num_patches, 2) where + - axis 1 (size 1) represents the temporal dimension + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + batch_size (`int`): + Batch size of the audio latents. + num_frames (`int`): + Number of latent frames in the audio latents. + device (`torch.device`): + Device on which to create the audio grid. + shift (`int`, *optional*, defaults to `0`): + Offset on the latent indices. Different shift values correspond to different overlapping windows with + respect to the same underlying latent grid. + + Returns: + `torch.Tensor`: + Per-dimension patch boundaries tensor of shape [batch_size, 1, num_patches, 2]. + """ + + # 1. Generate coordinates in the frame (time) dimension. + # Always compute rope in fp32 + grid_f = torch.arange( + start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device + ) + + # 2. 
Calculate start timstamps in seconds with respect to the original spectrogram grid + audio_scale_factor = self.scale_factors[0] + # Scale back to mel spectrogram space + grid_start_mel = grid_f * audio_scale_factor + # Handle first frame causal offset, ensuring non-negative timestamps + grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0) + # Convert mel bins back into seconds + grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate + + # 3. Calculate start timstamps in seconds with respect to the original spectrogram grid + grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor + grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip(min=0) + grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate + + audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1) # [num_patches, 2] + audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1) # [batch_size, num_patches, 2] + audio_coords = audio_coords.unsqueeze(1) # [batch_size, 1, num_patches, 2] + return audio_coords + + def prepare_coords(self, *args, **kwargs): + if self.modality == "video": + return self.prepare_video_coords(*args, **kwargs) + elif self.modality == "audio": + return self.prepare_audio_coords(*args, **kwargs) + + def forward( + self, + coords: Optional[torch.Tensor] = None, + batch_size: Optional[int] = None, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + fps: float = 25.0, + shift: int = 0, + device: Optional[Union[str, torch.device]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if coords is not None: + device = device or coords.device + batch_size = batch_size or coords.size(0) + else: + device = device or "cpu" + batch_size = batch_size or 1 + + # 1. Calculate the coordinate grid with respect to data space for the given modality (video, audio). + if coords is None and self.modality == "video": + coords = self.prepare_video_coords( + batch_size, + num_frames, + height, + width, + device=device, + ) + # Scale the temporal coordinates by the video FPS + coords[:, 0, ...] = coords[:, 0, ...] / fps + elif coords is None and self.modality == "audio": + coords = self.prepare_audio_coords( + batch_size, + num_frames, + device=device, + shift=shift, + ) + # Number of spatiotemporal dimensions (3 for video, 1 for audio) + num_pos_dims = coords.shape[1] + + # 2. If the coords are patch boundaries [start, end), use the midpoint of these boundaries + if coords.ndim == 4: + coords_start, coords_end = coords.chunk(2, dim=-1) + coords = (coords_start + coords_end) / 2.0 + coords = coords.squeeze(-1) # [B, num_pos_dims, num_patches] + + # 3. Get coordinates as a fraction of the base data shape + if self.modality == "video": + max_positions = (self.base_num_frames, self.base_height, self.base_width) + elif self.modality == "audio": + max_positions = (self.base_num_frames,) + grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device) + # Number of spatiotemporal dimensions (3 for video, 1 for audio) times 2 for cos, sin + num_rope_elems = num_pos_dims * 2 + + # 4. Create a 1D grid of frequencies for RoPE + start = 1.0 + end = self.theta + freqs = self.theta ** torch.linspace( + start=math.log(start, self.theta), + end=math.log(end, self.theta), + steps=self.dim // num_rope_elems, + device=device, + dtype=torch.float32, + ) + freqs = freqs * math.pi / 2.0 + + # 5. 
Tensor-vector outer product between pos ids tensor of shape [B, 3, num_patches] and freqs vector of shape + # self.dim // num_elems + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, 3, num_patches, self.dim // num_elems] + freqs = freqs.transpose(1, 2).flatten(2) # [B, num_patches, self.dim // 2] + # freqs = freqs.transpose(-1, -2).flatten(2) # [B, 3, num_patches * self.dim // num_elems]??? + + # 6. Get real, interleaved (cos, sin) frequencies, padded to self.dim + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + return cos_freqs, sin_freqs + + +@maybe_allow_in_graph +class LTX2VideoTransformer3DModel( + ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin +): + r""" + A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). + + Args: + in_channels (`int`, defaults to `128`): + The number of channels in the input. + out_channels (`int`, defaults to `128`): + The number of channels in the output. + patch_size (`int`, defaults to `1`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the tmeporal patches to use in the patch embedding layer. + num_attention_heads (`int`, defaults to `32`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + cross_attention_dim (`int`, defaults to `2048 `): + The number of channels for cross attention heads. + num_layers (`int`, defaults to `28`): + The number of layers of Transformer blocks to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + qk_norm (`str`, defaults to `"rms_norm_across_heads"`): + The normalization layer to use. 
+ """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["LTXVideoTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + "rope": { + 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + 1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } + + @register_to_config + def __init__( + self, + in_channels: int = 128, # Video Arguments + out_channels: Optional[int] = 128, + patch_size: int = 1, + patch_size_t: int = 1, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + cross_attention_dim: int = 4096, + vae_scale_factors: Tuple[int, int, int] = (8, 32, 32), + pos_embed_max_pos: int = 20, + base_height: int = 2048, + base_width: int = 2048, + audio_in_channels: int = 128, # Audio Arguments + audio_out_channels: Optional[int] = 128, + audio_patch_size: int = 1, + audio_patch_size_t: int = 1, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_cross_attention_dim: int = 2048, + audio_scale_factor: int = 4, + audio_pos_embed_max_pos: int = 20, + audio_sampling_rate: int = 16000, + audio_hop_length: int = 160, + num_layers: int = 48, # Shared arguments + activation_fn: str = "gelu-approximate", + qk_norm: str = "rms_norm_across_heads", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 3840, + attention_bias: bool = True, + attention_out_bias: bool = True, + rope_theta: float = 10000.0, + causal_offset: int = 1, + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + audio_out_channels = audio_out_channels or audio_in_channels + inner_dim = num_attention_heads * attention_head_dim + audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + + # 1. Patchification input projections + self.proj_in = nn.Linear(in_channels, inner_dim) + self.audio_proj_in = nn.Linear(audio_in_channels, inner_dim) + + # 2. Prompt embeddings + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=audio_inner_dim + ) + + # 3. Timestep Modulation Params and Embedding + # 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding + # time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters + self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False) + self.audio_time_embed = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=6, use_additional_conditions=False + ) + + # 3.2. Global Cross Attention Modulation Parameters + # Used in the audio-to-video and video-to-audio cross attention layers as a global set of modulation params, + # which are then further modified by per-block modulaton params in each transformer block. 
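# A minimal sketch (not part of this patch) of the adaLN-single pattern that the LTX2AdaLayerNormSingle
# modules above implement, and of how each transformer block later combines these global parameters with
# its own learned table. `per_layer_table`, `B`, `S`, and `D` are illustrative placeholder names only:
#
#     temb = linear(silu(timestep_embedding))                      # [B, num_mod_params * D]
#     ada = per_layer_table[None, None] + temb.reshape(B, S, num_mod_params, -1)
#     shift, scale, gate, shift_mlp, scale_mlp, gate_mlp = ada.unbind(dim=2)   # num_mod_params == 6
#     x = norm(x) * (1 + scale) + shift                            # pre-attention modulation
#     x = x + gate * attention(x)                                  # gated residual update
#
# In the actual code the per-layer tables are `scale_shift_table`, `audio_scale_shift_table`, and the
# `*_cross_attn_scale_shift_table` parameters defined in LTX2VideoTransformerBlock.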
+ # There are 2 sets of scale/shift parameters for each modality, 1 each for audio-to-video (a2v) and + # video-to-audio (v2a) cross attention + self.av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=4, use_additional_conditions=False + ) + self.av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=4, use_additional_conditions=False + ) + # Gate param for audio-to-video (a2v) cross attn (where the video is the queries (Q) and the audio is the keys + # and values (KV)) + self.av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=1, use_additional_conditions=False + ) + # Gate param for video-to-audio (v2a) cross attn (where the audio is the queries (Q) and the video is the keys + # and values (KV)) + self.av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=1, use_additional_conditions=False + ) + + # 3.3. Output Layer Scale/Shift Modulation parameters + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5) + + # 4. Rotary Positional Embeddings (RoPE) + # Self-Attention + self.rope = LTX2AudioVideoRotaryPosEmbed( + dim=inner_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + scale_factors=vae_scale_factors, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + ) + self.audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_inner_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=audio_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + scale_factors=[audio_scale_factor], + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + ) + + # Audio-to-Video, Video-to-Audio Cross-Attention + cross_attn_pos_embed_max_pos = max(pos_embed_max_pos, audio_pos_embed_max_pos) + self.cross_attn_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + ) + self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + ) + + # 5. Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + LTX2VideoTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + ) + for _ in range(num_layers) + ] + ) + + # 6. 
Output layers + self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels) + + self.audio_norm_out = nn.LayerNorm(audio_inner_dim, eps=1e-6, elementwise_affine=False) + self.audio_proj_out = nn.Linear(audio_inner_dim, audio_out_channels) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + fps: float = 25.0, + video_coords: Optional[torch.Tensor] = None, + audio_coords: Optional[torch.Tensor] = None, + timestep_scale_multiplier: int = 1000, + cross_attn_timestep_scale_multiplier: int = 1, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> torch.Tensor: + """ + Forward pass for LTX-2.0 audiovisual video transformer. + + Args: + hidden_states (`torch.Tensor`): + Input patchified video latents of shape (batch_size, num_video_tokens, in_channels). + audio_hidden_states (`torch.Tensor`): + Input patchified audio latents of shape (batch_size, num_audio_tokens, audio_in_channels). + encoder_hidden_states (`torch.Tensor`): + Input text embeddings of shape TODO. + timesteps (`torch.Tensor`): + Timestep information of shape (batch_size, num_train_timesteps). + + Returns: + `AudioVisualModelOutput` or `tuple`: + If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a + `tuple` is returned where the first element is the denoised video latent patch sequence and the second + element is the denoised audio latent patch sequence. + """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + + # 1. Prepare RoPE positional embeddings + if video_coords is None: + video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, hidden_states.device) + if audio_coords is None: + audio_coords = self.audio_rope.prepare_audio_coords(batch_size, num_frames, audio_hidden_states.device) + + video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) + audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) + + video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) + audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :], device=audio_hidden_states.device) + + # 2. Patchify input projections + hidden_states = self.proj_in(hidden_states) + audio_hidden_states = self.audio_proj_in(audio_hidden_states) + + # 3. 
Prepare timestep embeddings and modulation parameters + # Scale timestep + timestep = timestep * timestep_scale_multiplier + timestep_cross_attn_gate_scale_factor = cross_attn_timestep_scale_multiplier / timestep_scale_multiplier + + # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters + # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer + # modulation with scale_shift_table (and similarly for audio) + temb, embedded_timestep = self.time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + + temb_audio, audio_embedded_timestep = self.audio_time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + temb_audio = temb.view(batch_size, -1, temb_audio.size(-1)) + audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) + + # 3.2. Prepare global modality cross attention modulation parameters + video_cross_attn_scale_shift = self.av_cross_attn_video_scale_shift( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_a2v_gate = self.av_cross_attn_video_a2v_gate( + timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(batch_size, -1, video_cross_attn_scale_shift.shape[-1]) + video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) + + audio_cross_attn_scale_shift = self.av_cross_attn_audio_scale_shift( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_v2a_gate = self.av_cross_attn_audio_a2v_gate( + timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(batch_size, -1, audio_cross_attn_scale_shift.shape[-1]) + audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) + + # 4. Prepare prompt embeddings + # TODO: does the audio prompt embedding start from the same text embeddings as the video one? + audio_encoder_hidden_states = self.audio_caption_projection(encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + # 5. 
Run transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, audio_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + audio_hidden_states, + encoder_hidden_states, + temb, + temb_audio, + video_cross_attn_scale_shift, + audio_cross_attn_scale_shift, + video_cross_attn_a2v_gate, + audio_cross_attn_v2a_gate, + video_rotary_emb, + audio_rotary_emb, + video_cross_attn_rotary_emb, + audio_cross_attn_rotary_emb, + encoder_attention_mask, + ) + else: + hidden_states, audio_hidden_states = block( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + temb_audio=temb_audio, + temb_ca_scale_shift=video_cross_attn_scale_shift, + temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, + temb_ca_gate=video_cross_attn_a2v_gate, + temb_ca_audio_gate=audio_cross_attn_v2a_gate, + video_rotary_emb=video_rotary_emb, + audio_rotary_emb=audio_rotary_emb, + ca_video_rotary_emb=video_cross_attn_rotary_emb, + ca_audio_rotary_emb=audio_cross_attn_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + ) + + # 6. Output layers (including unpatchification) + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + output = self.proj_out(hidden_states) + + audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None] + audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1] + + audio_hidden_states = self.audio_norm_out(audio_hidden_states) + audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift + audio_output = self.audio_proj_out(audio_hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output, audio_output) + return AudioVisualModelOutput(sample=output, audio_sample=audio_output) From b3096c3c9eaf24a9778e3c30162ad22c6ae4af8f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 13 Dec 2025 04:55:41 +0100 Subject: [PATCH 02/86] Add tests for LTX 2 transformer model --- .../test_models_transformer_ltx2.py | 115 ++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 tests/models/transformers/test_models_transformer_ltx2.py diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py new file mode 100644 index 000000000000..d67789acca5c --- /dev/null +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
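# Illustrative sketch only (not part of the test suite): the tests below exercise a tiny configuration of
# LTX2VideoTransformer3DModel. Run directly, with all sizes mirroring `prepare_init_args_and_inputs_for_common`
# and `dummy_input` further down in this file, a standalone forward pass would look roughly like this:
#
#     model = LTX2VideoTransformer3DModel(
#         in_channels=4, out_channels=4, num_attention_heads=2, attention_head_dim=8,
#         cross_attention_dim=16, audio_in_channels=4, audio_out_channels=4,
#         audio_num_attention_heads=2, audio_attention_head_dim=8, audio_cross_attention_dim=16,
#         num_layers=2, qk_norm="rms_norm_across_heads", caption_channels=16,
#     )
#     out = model(
#         hidden_states=torch.randn(2, 2 * 16 * 16, 4),             # patchified video latents
#         audio_hidden_states=torch.randn(2, 2, 4),                 # patchified audio latents
#         encoder_hidden_states=torch.randn(2, 16, 16),             # text embeddings
#         audio_encoder_hidden_states=torch.randn(2, 16, 16),
#         encoder_attention_mask=torch.ones(2, 16).bool(),
#         timestep=torch.randint(0, 1000, (2,)),
#         num_frames=2, height=16, width=16, fps=25.0,
#     )
#     video_pred, audio_pred = out.sample, out.audio_sample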
+ +import unittest + +import torch + +from diffusers import LTX2VideoTransformer3DModel + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + +class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = LTX2VideoTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + # Common + batch_size = 2 + # NOTE: at 25 FPS, using the same num_frames for hidden_states and audio_hidden_states will result in video + # and audio of equal duration + num_frames = 2 + + # Video + num_channels = 4 + height = 16 + width = 16 + + # Audio + audio_num_channels = 2 + num_mel_bins = 2 + + # Text + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) + audio_hidden_states = torch.randn( + (batch_size, num_frames, audio_num_channels * num_mel_bins) + ).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "audio_hidden_states": audio_hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "audio_encoder_hidden_states": audio_encoder_hidden_states, + "timestep": timestep, + "encoder_attention_mask": encoder_attention_mask, + "num_frames": num_frames, + "height": height, + "width": width, + "fps": 25.0, + } + + @property + def input_shape(self): + return (512, 4) + + @property + def output_shape(self): + return (512, 4) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "audio_in_channels": 4, + "audio_out_channels": 4, + "audio_num_attention_heads": 2, + "audio_attention_head_dim": 8, + "audio_cross_attention_dim": 16, + "num_layers": 2, + "qk_norm": "rms_norm_across_heads", + "caption_channels": 16, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"LTX2VideoTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = LTX2VideoTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return LTX2TransformerTests().prepare_init_args_and_inputs_for_common() From 980591de53bac7425234ebe8b295bd0375828ee8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 13 Dec 2025 04:57:23 +0100 Subject: [PATCH 03/86] Get LTX 2 transformer tests working --- .../models/transformers/transformer_ltx2.py | 51 ++++++++++++------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 57d71a3eb6a2..93b59dec51b0 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -577,6 +577,7 @@ def __init__( # Audio-specific 
self.sampling_rate = sampling_rate self.hop_length = hop_length + self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0]) self.scale_factors = scale_factors self.theta = theta @@ -657,6 +658,7 @@ def prepare_audio_coords( batch_size: int, num_frames: int, device: torch.device, + fps: float = 25.0, shift: int = 0, ) -> torch.Tensor: """ @@ -682,9 +684,11 @@ def prepare_audio_coords( """ # 1. Generate coordinates in the frame (time) dimension. + audio_duration_s = num_frames / fps + latent_frames = int(audio_duration_s * self.audio_latents_per_second) # Always compute rope in fp32 grid_f = torch.arange( - start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device + start=shift, end=latent_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device ) # 2. Calculate start timstamps in seconds with respect to the original spectrogram grid @@ -748,10 +752,11 @@ def forward( device=device, shift=shift, ) - # Number of spatiotemporal dimensions (3 for video, 1 for audio) + # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn) num_pos_dims = coords.shape[1] - # 2. If the coords are patch boundaries [start, end), use the midpoint of these boundaries + # 2. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch + # position index if coords.ndim == 4: coords_start, coords_end = coords.chunk(2, dim=-1) coords = (coords_start + coords_end) / 2.0 @@ -762,8 +767,9 @@ def forward( max_positions = (self.base_num_frames, self.base_height, self.base_width) elif self.modality == "audio": max_positions = (self.base_num_frames,) + # [B, num_pos_dims, num_patches] --> [B, num_patches, num_pos_dims] grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device) - # Number of spatiotemporal dimensions (3 for video, 1 for audio) times 2 for cos, sin + # Number of spatiotemporal dimensions (3 for video, 1 for audio and cross attn) times 2 for cos, sin num_rope_elems = num_pos_dims * 2 # 4. Create a 1D grid of frequencies for RoPE @@ -778,11 +784,10 @@ def forward( ) freqs = freqs * math.pi / 2.0 - # 5. Tensor-vector outer product between pos ids tensor of shape [B, 3, num_patches] and freqs vector of shape - # self.dim // num_elems - freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, 3, num_patches, self.dim // num_elems] - freqs = freqs.transpose(1, 2).flatten(2) # [B, num_patches, self.dim // 2] - # freqs = freqs.transpose(-1, -2).flatten(2) # [B, 3, num_patches * self.dim // num_elems]??? + # 5. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape + # (self.dim // num_elems,) + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, num_patches, num_pos_dims, self.dim // num_elems] + freqs = freqs.transpose(-1, -2).flatten(2) # [B, num_patches, self.dim // 2] # 6. Get real, interleaved (cos, sin) frequencies, padded to self.dim cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) @@ -888,7 +893,7 @@ def __init__( # 1. Patchification input projections self.proj_in = nn.Linear(in_channels, inner_dim) - self.audio_proj_in = nn.Linear(audio_in_channels, inner_dim) + self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim) # 2. 
Prompt embeddings self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) @@ -990,6 +995,10 @@ def __init__( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, cross_attention_dim=cross_attention_dim, + audio_dim=audio_inner_dim, + audio_num_attention_heads=audio_num_attention_heads, + audio_attention_head_dim=audio_attention_head_dim, + audio_cross_attention_dim=audio_cross_attention_dim, qk_norm=qk_norm, activation_fn=activation_fn, attention_bias=attention_bias, @@ -1015,6 +1024,7 @@ def forward( hidden_states: torch.Tensor, audio_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, num_frames: Optional[int] = None, @@ -1077,9 +1087,13 @@ def forward( video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) + print(f"Video RoPE cos shape: {video_rotary_emb[0].shape} | sin shape: {video_rotary_emb[1].shape}") + print(f"Audio RoPE cos shape: {audio_rotary_emb[0].shape} | sin shape: {audio_rotary_emb[1].shape}") video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :], device=audio_hidden_states.device) + print(f"Video CA RoPE cos shape: {video_cross_attn_rotary_emb[0].shape} | sin shape: {video_cross_attn_rotary_emb[1].shape}") + print(f"Audio CA RoPE cos shape: {audio_cross_attn_rotary_emb[0].shape} | sin shape: {audio_cross_attn_rotary_emb[1].shape}") # 2. Patchify input projections hidden_states = self.proj_in(hidden_states) @@ -1110,12 +1124,12 @@ def forward( audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) # 3.2. Prepare global modality cross attention modulation parameters - video_cross_attn_scale_shift = self.av_cross_attn_video_scale_shift( + video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype, ) - video_cross_attn_a2v_gate = self.av_cross_attn_video_a2v_gate( + video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( timestep.flatten() * timestep_cross_attn_gate_scale_factor, batch_size=batch_size, hidden_dtype=hidden_states.dtype, @@ -1123,12 +1137,12 @@ def forward( video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(batch_size, -1, video_cross_attn_scale_shift.shape[-1]) video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) - audio_cross_attn_scale_shift = self.av_cross_attn_audio_scale_shift( + audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) - audio_cross_attn_v2a_gate = self.av_cross_attn_audio_a2v_gate( + audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( timestep.flatten() * timestep_cross_attn_gate_scale_factor, batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, @@ -1137,13 +1151,12 @@ def forward( audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) # 4. Prepare prompt embeddings - # TODO: does the audio prompt embedding start from the same text embeddings as the video one? 
- audio_encoder_hidden_states = self.audio_caption_projection(encoder_hidden_states) - audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) - encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) + # 5. Run transformer blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -1152,6 +1165,7 @@ def forward( hidden_states, audio_hidden_states, encoder_hidden_states, + audio_encoder_hidden_states, temb, temb_audio, video_cross_attn_scale_shift, @@ -1169,6 +1183,7 @@ def forward( hidden_states=hidden_states, audio_hidden_states=audio_hidden_states, encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, temb=temb, temb_audio=temb_audio, temb_ca_scale_shift=video_cross_attn_scale_shift, From e100b8f2a3f7d7d88ac0ca6c33a47a0dd215d8f1 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 13 Dec 2025 10:34:11 +0100 Subject: [PATCH 04/86] Rename LTX 2 compile test class to have LTX2 --- tests/models/transformers/test_models_transformer_ltx2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index d67789acca5c..fc089e6190ae 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -108,7 +108,7 @@ def test_gradient_checkpointing_is_applied(self): super().test_gradient_checkpointing_is_applied(expected_set=expected_set) -class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): +class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = LTX2VideoTransformer3DModel def prepare_init_args_and_inputs_for_common(self): From 780fb61d32a7a664eec978a7d7c98784394386cc Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 13 Dec 2025 10:37:24 +0100 Subject: [PATCH 05/86] Remove RoPE debug print statements --- src/diffusers/models/transformers/transformer_ltx2.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 93b59dec51b0..f74f608457f3 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -1087,13 +1087,9 @@ def forward( video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) - print(f"Video RoPE cos shape: {video_rotary_emb[0].shape} | sin shape: {video_rotary_emb[1].shape}") - print(f"Audio RoPE cos shape: {audio_rotary_emb[0].shape} | sin shape: {audio_rotary_emb[1].shape}") video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :], device=audio_hidden_states.device) - print(f"Video CA RoPE cos shape: {video_cross_attn_rotary_emb[0].shape} | sin shape: {video_cross_attn_rotary_emb[1].shape}") - print(f"Audio CA RoPE cos shape: {audio_cross_attn_rotary_emb[0].shape} 
| sin shape: {audio_cross_attn_rotary_emb[1].shape}") # 2. Patchify input projections hidden_states = self.proj_in(hidden_states) From 5765759cd33c693de60db2a4805990a12002fd61 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 15 Dec 2025 03:38:34 +0100 Subject: [PATCH 06/86] Get LTX 2 transformer compile tests passing --- src/diffusers/models/transformers/transformer_ltx2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index f74f608457f3..c3bb7a00a4cc 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -246,7 +246,6 @@ def forward( return hidden_states -@maybe_allow_in_graph class LTX2VideoTransformerBlock(nn.Module): r""" Transformer block used in [LTX-2.0](https://huggingface.co/Lightricks/LTX-Video). @@ -802,7 +801,6 @@ def forward( return cos_freqs, sin_freqs -@maybe_allow_in_graph class LTX2VideoTransformer3DModel( ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin ): @@ -834,7 +832,7 @@ class LTX2VideoTransformer3DModel( _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] - _repeated_blocks = ["LTXVideoTransformerBlock"] + _repeated_blocks = ["LTX2VideoTransformerBlock"] _cp_plan = { "": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), From aeecc4d7125e111a7dba491036dbf1e26f759ecd Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 15 Dec 2025 06:38:57 +0100 Subject: [PATCH 07/86] Fix LTX 2 transformer shape errors --- .../models/transformers/transformer_ltx2.py | 16 ++++++++-------- .../transformers/test_models_transformer_ltx2.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index c3bb7a00a4cc..c1ad5f180fea 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -297,7 +297,7 @@ def __init__( qk_norm=qk_norm, ) - self.audio_norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) self.audio_attn1 = LTX2Attention( query_dim=audio_dim, heads=audio_num_attention_heads, @@ -322,7 +322,7 @@ def __init__( qk_norm=qk_norm, ) - self.audio_norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) self.audio_attn2 = LTX2Attention( query_dim=audio_dim, cross_attention_dim=audio_cross_attention_dim, @@ -349,7 +349,7 @@ def __init__( ) # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video - self.video_to_audio_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) self.video_to_audio_attn = LTX2Attention( query_dim=audio_dim, cross_attention_dim=dim, @@ -365,13 +365,13 @@ def __init__( self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) self.ff = FeedForward(dim, activation_fn=activation_fn) - self.audio_norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn) # 5. 
Per-Layer Modulation Parameters # Self-Attention / Feedforward AdaLayerNorm-Zero mod params self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) - self.audio_scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5) # Per-layer a2v, v2a Cross-Attention mod params self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) @@ -459,8 +459,8 @@ def forward( # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention if use_a2v_cross_attn or use_v2a_cross_attn: - norm_hidden_states = self.norm3(hidden_states) - norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) + norm_hidden_states = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) # Combine global and per-layer cross attention modulation parameters # Video @@ -1114,7 +1114,7 @@ def forward( batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) - temb_audio = temb.view(batch_size, -1, temb_audio.size(-1)) + temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) # 3.2. Prepare global modality cross attention modulation parameters diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index fc089e6190ae..c382a63eaaf8 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -94,8 +94,8 @@ def prepare_init_args_and_inputs_for_common(self): "audio_in_channels": 4, "audio_out_channels": 4, "audio_num_attention_heads": 2, - "audio_attention_head_dim": 8, - "audio_cross_attention_dim": 16, + "audio_attention_head_dim": 4, + "audio_cross_attention_dim": 8, "num_layers": 2, "qk_norm": "rms_norm_across_heads", "caption_channels": 16, From a5f2d2da6c4131449a9726e01342b20e12ab2110 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 15 Dec 2025 07:09:42 +0100 Subject: [PATCH 08/86] Initial script to convert LTX 2 transformer to diffusers --- scripts/convert_ltx2_to_diffusers.py | 318 +++++++++++++++++++++++++++ 1 file changed, 318 insertions(+) create mode 100644 scripts/convert_ltx2_to_diffusers.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py new file mode 100644 index 000000000000..286e2aed42e1 --- /dev/null +++ b/scripts/convert_ltx2_to_diffusers.py @@ -0,0 +1,318 @@ +import argparse +import os +from contextlib import nullcontext +from typing import Any, Dict, Optional, Tuple + +import safetensors.torch +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers import LTX2VideoTransformer3DModel +from diffusers.utils.import_utils import is_accelerate_available + + +CTX = init_empty_weights if is_accelerate_available() else nullcontext + + +LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = { + # Input Patchify Projections + "patchify_proj": "proj_in", + "audio_patchify_proj": "audio_proj_in", + # Modulation Parameters + # Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are + # substrings of the other modulation parameters below + "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift", + "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate", + "av_ca_audio_scale_shift_adaln_single": 
"av_cross_attn_audio_scale_shift", + "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate", + # Transformer Blocks + # Per-Block Cross Attention Modulatin Parameters + "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table", + "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + + +def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + +def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None: + state_dict.pop(key) + + +def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None: + # Skip if not a weight, bias + if ".weight" not in key and ".bias" not in key: + return + + if key.startswith("adaln_single."): + new_key = key.replace("adaln_single.", "time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + if key.startswith("audio_adaln_single."): + new_key = key.replace("audio_adaln_single.", "audio_time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + +LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = { + "video_embeddings_connector": remove_keys_inplace, + "audio_embeddings_connector": remove_keys_inplace, + "adaln_single": convert_ltx2_transformer_adaln_single, +} + + +def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + # Produces a transformer of the same size as used in test_models_transformer_ltx2.py + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 4, + "out_channels": 4, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "vae_scale_factors": (8, 32 ,32), + "pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "audio_in_channels": 4, + "audio_out_channels": 4, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 2, + "audio_attention_head_dim": 4, + "audio_cross_attention_dim": 8, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "num_layers": 2, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 16, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "causal_offset": 1, + }, + } + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 128, + "cross_attention_dim": 4096, + "vae_scale_factors": (8, 32 ,32), + "pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "audio_in_channels": 128, + "audio_out_channels": 128, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 32, + "audio_attention_head_dim": 64, + "audio_cross_attention_dim": 2048, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "num_layers": 48, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": 
False, + "norm_eps": 1e-6, + "caption_channels": 3840, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "causal_offset": 1, + }, + } + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True, assign=True) + return transformer + + +def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + + +def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]: + # Ensure that the key prefix ends with a dot (.) + if not prefix.endswith("."): + prefix = prefix + "." + + model_state_dict = {} + for param_name, param in combined_ckpt.items(): + if param_name.startswith(prefix): + model_state_dict[param_name.replace(prefix, "")] = param + return model_state_dict + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--original_state_dict_repo_id", + default="diffusers-internal-dev/new-ltx-model", + type=str, + help="HF Hub repo id with LTX 2.0 checkpoint", + ) + parser.add_argument( + "--checkpoint_path", + default=None, + type=str, + help="Local checkpoint path for LTX 2.0. 
Will be used if `original_state_dict_repo_id` is not specified.", + ) + parser.add_argument( + "--version", + type=str, + default="2.0", + choices=["test", "2.0"], + help="Version of the LTX 2.0 model", + ) + + parser.add_argument( + "--combined_filename", + default="ltx-av-step-1932500-interleaved-new-vae.safetensors", + type=str, + help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)", + ) + parser.add_argument("--vae_prefix", default="vae.", type=str) + parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str) + parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str) + parser.add_argument("--vocoder_prefix", default="vocoder.", type=str) + + parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set") + parser.add_argument( + "--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set" + ) + parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set") + parser.add_argument( + "--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set" + ) + + parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") + parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") + parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") + parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model") + parser.add_argument( + "--full_pipeline", + action="store_true", + help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)", + ) + + parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +VARIANT_MAPPING = { + "fp32": None, + "fp16": "fp16", + "bf16": "bf16", +} + + +def main(args): + vae_dtype = DTYPE_MAPPING[args.vae_dtype] + audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype] + dit_dtype = DTYPE_MAPPING[args.dit_dtype] + vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype] + + combined_ckpt = None + load_combined_models = any([args.vae, args.audio_vae, args.dit, args.vocoder, args.full_pipeline]) + if args.combined_filename is not None and load_combined_models: + combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename) + + if args.vae or args.full_pipeline: + pass + + if args.audio_vae or args.full_pipeline: + pass + + if args.dit or args.full_pipeline: + if args.dit_filename is not None: + original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename) + elif combined_ckpt is not None: + original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix) + transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version) + if not args.full_pipeline: + 
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer")) + + if args.vocoder or args.full_pipeline: + pass + + if args.full_pipeline: + pass + + +if __name__ == '__main__': + args = get_args() + main(args) From d86f89ddea76952279af1da5ff188562f615325f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 16 Dec 2025 07:58:12 +0100 Subject: [PATCH 09/86] Add more LTX 2 transformer audio arguments --- .../models/transformers/transformer_ltx2.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index c1ad5f180fea..2ce6106eecfc 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -394,6 +394,7 @@ def forward( ca_video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ca_audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, a2v_cross_attention_mask: Optional[torch.Tensor] = None, v2a_cross_attention_mask: Optional[torch.Tensor] = None, use_video_self_attn: bool = True, @@ -453,7 +454,7 @@ def forward( norm_audio_hidden_states, encoder_hidden_states=audio_encoder_hidden_states, query_rotary_emb=None, - attention_mask=encoder_attention_mask, + attention_mask=audio_encoder_attention_mask, ) hidden_states = hidden_states + attn_hidden_states @@ -1024,11 +1025,13 @@ def forward( encoder_hidden_states: torch.Tensor, audio_encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, - encoder_attention_mask: torch.Tensor, + encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, num_frames: Optional[int] = None, height: Optional[int] = None, width: Optional[int] = None, fps: float = 25.0, + audio_num_frames: Optional[int] = None, video_coords: Optional[torch.Tensor] = None, audio_coords: Optional[torch.Tensor] = None, timestep_scale_multiplier: int = 1000, @@ -1075,13 +1078,17 @@ def forward( encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2: + audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0 + audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) + batch_size = hidden_states.size(0) # 1. 
Prepare RoPE positional embeddings if video_coords is None: video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, hidden_states.device) if audio_coords is None: - audio_coords = self.audio_rope.prepare_audio_coords(batch_size, num_frames, audio_hidden_states.device) + audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames, audio_hidden_states.device) video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) @@ -1171,6 +1178,7 @@ def forward( video_cross_attn_rotary_emb, audio_cross_attn_rotary_emb, encoder_attention_mask, + audio_encoder_attention_mask, ) else: hidden_states, audio_hidden_states = block( @@ -1189,6 +1197,7 @@ def forward( ca_video_rotary_emb=video_cross_attn_rotary_emb, ca_audio_rotary_emb=audio_cross_attn_rotary_emb, encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, ) # 6. Output layers (including unpatchification) From 57a8b9c3300201cc9609b882c3229bce3eb5cfeb Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 16 Dec 2025 10:38:03 +0100 Subject: [PATCH 10/86] Allow LTX 2 transformer to be loaded from local path for conversion --- scripts/convert_ltx2_to_diffusers.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 286e2aed42e1..312559dbee47 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -192,6 +192,26 @@ def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: return original_state_dict +def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]: + if repo_id is None and filename is None: + raise ValueError("Please supply at least one of `repo_id` or `filename`") + + if repo_id is not None: + if filename is None: + raise ValueError("If repo_id is specified, filename must also be specified.") + ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) + else: + ckpt_path = filename + + _, ext = os.path.splitext(ckpt_path) + if ext in [".safetensors", ".sft"]: + state_dict = safetensors.torch.load_file(ckpt_path) + else: + state_dict = torch.load(ckpt_path, map_location="cpu") + + return state_dict + + def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]: # Ensure that the key prefix ends with a dot (.) 
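    # e.g. a prefix of "model.diffusion_model" becomes "model.diffusion_model.", so a key such as
    # "model.diffusion_model.patchify_proj.weight" would be returned as "patchify_proj.weight"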
if not prefix.endswith("."): @@ -299,7 +319,7 @@ def main(args): if args.dit or args.full_pipeline: if args.dit_filename is not None: - original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename) + original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename) elif combined_ckpt is not None: original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix) transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version) From a7bc052e899936396dfcd08b0a5a88abe2088b5f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 16 Dec 2025 10:44:02 +0100 Subject: [PATCH 11/86] Improve dummy inputs and add test for LTX 2 transformer consistency --- .../test_models_transformer_ltx2.py | 122 +++++++++++++++++- 1 file changed, 116 insertions(+), 6 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index c382a63eaaf8..0bf08f161d43 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -17,7 +17,7 @@ import torch -from diffusers import LTX2VideoTransformer3DModel +from diffusers import LTX2VideoTransformer3DModel, attention_backend from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -35,16 +35,15 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): def dummy_input(self): # Common batch_size = 2 - # NOTE: at 25 FPS, using the same num_frames for hidden_states and audio_hidden_states will result in video - # and audio of equal duration - num_frames = 2 # Video + num_frames = 2 num_channels = 4 height = 16 width = 16 # Audio + audio_num_frames = 9 audio_num_channels = 2 num_mel_bins = 2 @@ -54,12 +53,12 @@ def dummy_input(self): hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) audio_hidden_states = torch.randn( - (batch_size, num_frames, audio_num_channels * num_mel_bins) + (batch_size, audio_num_frames, audio_num_channels * num_mel_bins) ).to(torch_device) encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) - timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + timestep = torch.rand((batch_size,)).to(torch_device) return { "hidden_states": hidden_states, @@ -71,6 +70,7 @@ def dummy_input(self): "num_frames": num_frames, "height": height, "width": width, + "audio_num_frames": audio_num_frames, "fps": 25.0, } @@ -107,6 +107,116 @@ def test_gradient_checkpointing_is_applied(self): expected_set = {"LTX2VideoTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + def test_ltx2_consistency(self, seed=0, dtype=torch.float32): + torch.manual_seed(seed) + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + # Calculate dummy inputs in a custom manner to ensure compatibility with original code + batch_size = 2 + num_frames = 9 + latent_frames = 2 + text_embedding_dim = 16 + text_seq_len = 16 + fps = 25.0 + sampling_rate = 16000.0 + hop_length = 160.0 + + sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") + timestep = (sigma * torch.ones((batch_size,), dtype=dtype, 
device="cpu")).to(device=torch_device) + + num_channels = 4 + latent_height = 4 + latent_width = 4 + hidden_states = torch.randn( + (batch_size, num_channels, latent_frames, latent_height, latent_width), + generator=torch.manual_seed(seed), + dtype=dtype, + device="cpu", + ) + # Patchify video latents (with patch_size (1, 1, 1)) + hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1) + hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + encoder_hidden_states = torch.randn( + (batch_size, text_seq_len, text_embedding_dim), + generator=torch.manual_seed(seed), + dtype=dtype, + device="cpu", + ) + + audio_num_channels = 2 + num_mel_bins = 2 + latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps)) + audio_hidden_states = torch.randn( + (batch_size, audio_num_channels, latent_length, num_mel_bins), + generator=torch.manual_seed(seed), + dtype=dtype, + device="cpu", + ) + # Patchify audio latents + audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3) + audio_encoder_hidden_states = torch.randn( + (batch_size, text_seq_len, text_embedding_dim), + generator=torch.manual_seed(seed), + dtype=dtype, + device="cpu", + ) + + inputs_dict = { + "hidden_states": hidden_states.to(device=torch_device), + "audio_hidden_states": audio_hidden_states.to(device=torch_device), + "encoder_hidden_states": encoder_hidden_states.to(device=torch_device), + "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device), + "timestep": timestep, + "num_frames": latent_frames, + "height": latent_height, + "width": latent_width, + "audio_num_frames": num_frames, + "fps": 25.0, + } + + model = self.model_class.from_pretrained( + "diffusers-internal-dev/dummy-ltx2", + subfolder="transformer", + device_map="cpu", + ) + # torch.manual_seed(seed) + # model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with attention_backend("native"): + with torch.no_grad(): + output = model(**inputs_dict) + + video_output, audio_output = output.to_tuple() + + self.assertIsNotNone(video_output) + self.assertIsNotNone(audio_output) + + # input & output have to have the same shape + video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels) + self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match") + audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins) + self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match") + + # Check against expected slice + # fmt: off + video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676]) + audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692]) + # fmt: on + + video_output_flat = video_output.cpu().flatten().float() + video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]]) + print(f"Video Expected Slice: {video_expected_slice}") + print(f"Video Generated Slice: {video_generated_slice}") + self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4)) + + audio_output_flat = audio_output.cpu().flatten().float() + audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]]) + 
print(f"Audio Expected Slice: {audio_expected_slice}") + print(f"Audio Generated Slice: {audio_generated_slice}") + self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4)) + class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = LTX2VideoTransformer3DModel From bda3ff13dbc895365fb6b3fcbb800df5f1844ecf Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 16 Dec 2025 10:53:43 +0100 Subject: [PATCH 12/86] Fix LTX 2 transformer bugs so consistency test passes --- .../models/transformers/transformer_ltx2.py | 22 +++++++++++++------ .../test_models_transformer_ltx2.py | 4 ---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 2ce6106eecfc..ea9bca115e99 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -456,7 +456,7 @@ def forward( query_rotary_emb=None, attention_mask=audio_encoder_attention_mask, ) - hidden_states = hidden_states + attn_hidden_states + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention if use_a2v_cross_attn or use_v2a_cross_attn: @@ -557,7 +557,7 @@ def __init__( base_width: int = 2048, sampling_rate: int = 16000, hop_length: int = 160, - scale_factors: Tuple[int, ...] = (8, 32 ,32), + scale_factors: Tuple[int, ...] = (8, 32, 32), theta: float = 10000.0, causal_offset: int = 1, modality: str = "video", @@ -594,6 +594,7 @@ def prepare_video_coords( height: int, width: int, device: torch.device, + fps: float = 25.0, ) -> torch.Tensor: """ Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original @@ -651,6 +652,9 @@ def prepare_video_coords( # and clamp to keep the first-frame timestamps causal and non-negative. pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0) + # Scale the temporal coordinates by the video FPS + pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps + return pixel_coords def prepare_audio_coords( @@ -742,15 +746,15 @@ def forward( height, width, device=device, + fps=fps, ) - # Scale the temporal coordinates by the video FPS - coords[:, 0, ...] = coords[:, 0, ...] / fps elif coords is None and self.modality == "audio": coords = self.prepare_audio_coords( batch_size, num_frames, device=device, shift=shift, + fps=fps, ) # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn) num_pos_dims = coords.shape[1] @@ -1086,9 +1090,13 @@ def forward( # 1. 
Prepare RoPE positional embeddings if video_coords is None: - video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, hidden_states.device) + video_coords = self.rope.prepare_video_coords( + batch_size, num_frames, height, width, hidden_states.device, fps=fps + ) if audio_coords is None: - audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames, audio_hidden_states.device) + audio_coords = self.audio_rope.prepare_audio_coords( + batch_size, audio_num_frames, audio_hidden_states.device, fps=fps + ) video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) @@ -1104,7 +1112,7 @@ def forward( # Scale timestep timestep = timestep * timestep_scale_multiplier timestep_cross_attn_gate_scale_factor = cross_attn_timestep_scale_multiplier / timestep_scale_multiplier - + # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer # modulation with scale_shift_table (and similarly for audio) diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index 0bf08f161d43..6c0b97c58906 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -207,14 +207,10 @@ def test_ltx2_consistency(self, seed=0, dtype=torch.float32): video_output_flat = video_output.cpu().flatten().float() video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]]) - print(f"Video Expected Slice: {video_expected_slice}") - print(f"Video Generated Slice: {video_generated_slice}") self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4)) audio_output_flat = audio_output.cpu().flatten().float() audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]]) - print(f"Audio Expected Slice: {audio_expected_slice}") - print(f"Audio Generated Slice: {audio_generated_slice}") self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4)) From 269cf7b40d3b5100637990907627b2254bf1897a Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 Dec 2025 10:51:34 +0100 Subject: [PATCH 13/86] Initial implementation of LTX 2.0 video VAE --- scripts/convert_ltx2_to_diffusers.py | 137 +- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoders/autoencoder_kl_ltx2.py | 1437 +++++++++++++++++ 5 files changed, 1577 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 312559dbee47..dfec0262deb8 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -8,7 +8,7 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers import LTX2VideoTransformer3DModel +from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel from diffusers.utils.import_utils import is_accelerate_available @@ -35,6 +35,32 @@ "k_norm": "norm_k", } +LTX_2_0_VIDEO_VAE_RENAME_DICT = { + # Encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.1", + 
"down_blocks.3": "down_blocks.1.downsamplers.0", + "down_blocks.4": "down_blocks.2", + "down_blocks.5": "down_blocks.2.downsamplers.0", + "down_blocks.6": "down_blocks.3", + "down_blocks.7": "down_blocks.3.downsamplers.0", + "down_blocks.8": "mid_block", + # Decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + # Common + # For all 3D ResNets + "res_blocks": "resnets", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} + def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: state_dict[new_key] = state_dict.pop(old_key) @@ -68,6 +94,11 @@ def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) "adaln_single": convert_ltx2_transformer_adaln_single, } +LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_inplace, + "per_channel_statistics.mean-of-stds": remove_keys_inplace, +} + def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": @@ -180,6 +211,102 @@ def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) return transformer +def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 2048), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": True, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 2048), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + 
"decoder_causal": True, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vae = AutoencoderKLLTX2Video.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) @@ -312,7 +439,13 @@ def main(args): combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename) if args.vae or args.full_pipeline: - pass + if args.vae_filename is not None: + original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) + elif combined_ckpt is not None: + original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) + vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version) + if not args.full_pipeline: + vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) if args.audio_vae or args.full_pipeline: pass diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 97ba02e2d03d..71cad3425f0b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -194,6 +194,7 @@ "AutoencoderKLHunyuanVideo", "AutoencoderKLHunyuanVideo15", "AutoencoderKLLTXVideo", + "AutoencoderKLLTX2Video", "AutoencoderKLMagvit", "AutoencoderKLMochi", "AutoencoderKLQwenImage", @@ -928,6 +929,7 @@ AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, + AutoencoderKLLTX2Video, AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index b387bd817c2d..3f4e49015b59 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -41,6 +41,7 @@ _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"] _import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] + _import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"] @@ -153,6 +154,7 @@ AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, 
AutoencoderKLLTXVideo, + AutoencoderKLLTX2Video, AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 56df27f93cd7..ca0cac1a57b7 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -10,6 +10,7 @@ from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo +from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py new file mode 100644 index 000000000000..9f65c9980d18 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -0,0 +1,1437 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoCausalConv3d +class LTXVideoCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + padding_mode: str = "zeros", + is_causal: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.is_causal = is_causal + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + + dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + height_pad = self.kernel_size[1] // 2 + width_pad = self.kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + self.kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + padding=padding, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + time_kernel_size = self.kernel_size[0] + + if 
self.is_causal: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) + else: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +# Like LTXVideoResnetBlock3d, but uses a normal Conv3d instead of a causal Conv3d for the conv_shortcut +class LTX2VideoResnetBlock3d(nn.Module): + r""" + A 3D ResNet block used in the LTX 2.0 audiovisual model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + elementwise_affine (`bool`, defaults to `False`): + Whether to enable elementwise affinity in the normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + eps: float = 1e-6, + elementwise_affine: bool = False, + non_linearity: str = "swish", + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.conv1 = LTXVideoCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) + + self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.dropout = nn.Dropout(dropout) + self.conv2 = LTXVideoCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) + + self.norm3 = None + self.conv_shortcut = None + if in_channels != out_channels: + self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) + # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d + self.conv_shortcut = nn.Conv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1 + ) + # self.conv_shortcut = LTXVideoCausalConv3d( + # in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal + # ) + + self.per_channel_scale1 = None + self.per_channel_scale2 = None + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + + self.scale_shift_table = None + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) + + def forward( + self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None + ) -> torch.Tensor: + hidden_states = inputs + + # Normalize over the channels dimension (dim 1), which is not the last dim + hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + temb = temb.unflatten(1, (4, -1)) + 
self.scale_shift_table[None, ..., None, None, None] + shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale_1) + shift_1 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if self.per_channel_scale1 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] + + hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + hidden_states = hidden_states * (1 + scale_2) + shift_2 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.per_channel_scale2 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...] + + if self.norm3 is not None: + inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) + + if self.conv_shortcut is not None: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoDownsampler3d +class LTXVideoDownsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + is_causal: bool = True, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels + + out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) + + self.conv = LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) + + residual = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + residual = residual.unflatten(1, (-1, self.group_size)) + residual = residual.mean(dim=2) + + hidden_states = self.conv(hidden_states) + hidden_states = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + hidden_states = hidden_states + residual + + return hidden_states + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoUpsampler3d +class LTXVideoUpsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + is_causal: bool = True, + residual: bool = False, + upscale_factor: int = 1, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.residual = residual + self.upscale_factor = upscale_factor + + 
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor + + self.conv = LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.residual: + residual = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor + residual = residual.repeat(1, repeats, 1, 1, 1) + residual = residual[:, :, self.stride[0] - 1 :] + + hidden_states = self.conv(hidden_states) + hidden_states = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + + if self.residual: + hidden_states = hidden_states + residual + + return hidden_states + + +# Like LTX 1.0 LTXVideo095DownBlock3D, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoDownBlock3D(nn.Module): + r""" + Down block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. 
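+        downsample_type (`str`, defaults to `"conv"`):
+            The type of downsampling layer to use when `spatio_temporal_scale` is enabled; one of `"conv"`,
+            `"spatial"`, `"temporal"`, or `"spatiotemporal"`.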
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + downsample_type: str = "conv", + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList() + + if downsample_type == "conv": + self.downsamplers.append( + LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + is_causal=is_causal, + ) + ) + elif downsample_type == "spatial": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal + ) + ) + elif downsample_type == "temporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal + ) + ) + elif downsample_type == "spatiotemporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal + ) + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d +# Like LTX 1.0 LTXVideoMidBlock3d, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoMidBlock3d(nn.Module): + r""" + A middle block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. 
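+        inject_noise (`bool`, defaults to `False`):
+            Whether to inject random spatial noise into the resnet blocks, scaled by learned per-channel parameters.
+        timestep_conditioning (`bool`, defaults to `False`):
+            Whether to condition the resnet blocks on a timestep embedding.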
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + ) -> None: + super().__init__() + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + r"""Forward method of the `LTXMidBlock3D` class.""" + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + return hidden_states + + +# Like LTXVideoUpBlock3d but with no conv_in and the updated LTX2VideoResnetBlock3d +class LTX2VideoUpBlock3d(nn.Module): + r""" + Up block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. 
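+        inject_noise (`bool`, defaults to `False`):
+            Whether to inject random spatial noise into the resnet blocks, scaled by learned per-channel parameters.
+        timestep_conditioning (`bool`, defaults to `False`):
+            Whether to condition the resnet blocks on a timestep embedding.
+        upsample_residual (`bool`, defaults to `False`):
+            Whether the upsampling layer adds a residual connection to its output.
+        upscale_factor (`int`, defaults to `1`):
+            Channel scaling factor forwarded to the upsampling layer, which consumes `out_channels * upscale_factor`
+            channels.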
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + upsample_residual: bool = False, + upscale_factor: int = 1, + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + self.conv_in = None + if in_channels != out_channels: + self.conv_in = LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + + self.upsamplers = None + if spatio_temporal_scale: + self.upsamplers = nn.ModuleList( + [ + LTXVideoUpsampler3d( + out_channels * upscale_factor, + stride=(2, 2, 2), + is_causal=is_causal, + residual=upsample_residual, + upscale_factor=upscale_factor, + ) + ] + ) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=out_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + if self.conv_in is not None: + hidden_states = self.conv_in(hidden_states, temb, generator) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + return hidden_states + + +# Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is +# different, as is the layers_per_block (the 2.0 VAE is bigger) +class LTXVideoEncoder3d(nn.Module): + r""" + The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + in_channels (`int`, defaults to 3): + Number of input channels. + out_channels (`int`, defaults to 128): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, True)`: + Whether a block should contain spatio-temporal downscaling layers or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`): + The number of layers per block. 
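+        down_block_types (`Tuple[str, ...]`):
+            The type of each down block; currently only `"LTX2VideoDownBlock3D"` is supported.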
+ downsample_type (`Tuple[str, ...]`, defaults to `("spatial", "temporal", "spatiotemporal", "spatiotemporal")`): + The spatiotemporal downsampling pattern per block. Per-layer values can be + - `"spatial"` (downsample spatial dims by 2x) + - `"temporal"` (downsample temporal dim by 2x) + - `"spatiotemporal"` (downsample both spatial and temporal dims by 2x) + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 128, + block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: Tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True), + layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2), + downsample_type: Tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + ): + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.in_channels = in_channels * patch_size**2 + + output_channel = out_channels + + self.conv_in = LTXVideoCausalConv3d( + in_channels=self.in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + is_causal=is_causal, + ) + + # down blocks + num_block_out_channels = len(block_out_channels) + self.down_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + output_channel = block_out_channels[i] + + if down_block_types[i] == "LTX2VideoDownBlock3D": + down_block = LTX2VideoDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + downsample_type=downsample_type[i], + ) + else: + raise ValueError(f"Unknown down block type: {down_block_types[i]}") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[-1], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + ) + + # out + self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.conv_act = nn.SiLU() + self.conv_out = LTXVideoCausalConv3d( + in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `LTXVideoEncoder3d` class.""" + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = hidden_states.reshape( + batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p + ) + # Thanks for driving me insane with the weird patching order :( + hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) + hidden_states = 
self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + last_channel = hidden_states[:, -1:] + last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) + hidden_states = torch.cat([hidden_states, last_channel], dim=1) + + return hidden_states + + +# Like LTX 1.0 LTXVideoDecoder3d, but has only 3 symmetric up blocks which are causal and residual with upsample_factor 2 +class LTXVideoDecoder3d(nn.Module): + r""" + The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, defaults to 128): + Number of latent channels. + out_channels (`int`, defaults to 3): + Number of output channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal upscaling layers or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `False`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + timestep_conditioning (`bool`, defaults to `False`): + Whether to condition the model on timesteps. + """ + + def __init__( + self, + in_channels: int = 128, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (256, 512, 1024), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True), + layers_per_block: Tuple[int, ...] = (5, 5, 5, 5), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + inject_noise: Tuple[bool, ...] = (False, False, False), + timestep_conditioning: bool = False, + upsample_residual: Tuple[bool, ...] = (True, True, True), + upsample_factor: Tuple[bool, ...] 
= (2, 2, 2), + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.out_channels = out_channels * patch_size**2 + + block_out_channels = tuple(reversed(block_out_channels)) + spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) + layers_per_block = tuple(reversed(layers_per_block)) + inject_noise = tuple(reversed(inject_noise)) + upsample_residual = tuple(reversed(upsample_residual)) + upsample_factor = tuple(reversed(upsample_factor)) + output_channel = block_out_channels[0] + + self.conv_in = LTXVideoCausalConv3d( + in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[0], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + inject_noise=inject_noise[0], + timestep_conditioning=timestep_conditioning, + ) + + # up blocks + num_block_out_channels = len(block_out_channels) + self.up_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel // upsample_factor[i] + output_channel = block_out_channels[i] // upsample_factor[i] + + up_block = LTX2VideoUpBlock3d( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i + 1], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + inject_noise=inject_noise[i + 1], + timestep_conditioning=timestep_conditioning, + upsample_residual=upsample_residual[i], + upscale_factor=upsample_factor[i], + ) + + self.up_blocks.append(up_block) + + # out + self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.conv_act = nn.SiLU() + self.conv_out = LTXVideoCausalConv3d( + in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal + ) + + # timestep embedding + self.time_embedder = None + self.scale_shift_table = None + self.timestep_scale_multiplier = None + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if self.timestep_scale_multiplier is not None: + temb = temb * self.timestep_scale_multiplier + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb) + else: + hidden_states = self.mid_block(hidden_states, temb) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1)) + temb = temb + self.scale_shift_table[None, ..., None, None, None] + shift, scale = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + 
scale) + shift + + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return hidden_states + + +class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [LTX](https://huggingface.co/Lightricks/LTX-Video). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Args: + in_channels (`int`, defaults to `3`): + Number of input channels. + out_channels (`int`, defaults to `3`): + Number of output channels. + latent_channels (`int`, defaults to `128`): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal downscaling or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + scaling_factor (`float`, *optional*, defaults to `1.0`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + encoder_causal (`bool`, defaults to `True`): + Whether the encoder should behave causally (future frames depend only on past frames) or not. + decoder_causal (`bool`, defaults to `False`): + Whether the decoder should behave causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 128, + block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: Tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + decoder_block_out_channels: Tuple[int, ...] = (256, 512, 1024), + layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2), + decoder_layers_per_block: Tuple[int, ...] = (5, 5, 5, 5), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True), + decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True), + decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False), + downsample_type: Tuple[str, ...] 
= ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + upsample_residual: Tuple[bool, ...] = (True, True, True), + upsample_factor: Tuple[int, ...] = (2, 2, 2), + timestep_conditioning: bool = False, + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + scaling_factor: float = 1.0, + encoder_causal: bool = True, + decoder_causal: bool = True, + spatial_compression_ratio: int = None, + temporal_compression_ratio: int = None, + ) -> None: + super().__init__() + + self.encoder = LTXVideoEncoder3d( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=block_out_channels, + down_block_types=down_block_types, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + downsample_type=downsample_type, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=encoder_causal, + ) + self.decoder = LTXVideoDecoder3d( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=decoder_block_out_channels, + spatio_temporal_scaling=decoder_spatio_temporal_scaling, + layers_per_block=decoder_layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=decoder_causal, + timestep_conditioning=timestep_conditioning, + inject_noise=decoder_inject_noise, + upsample_residual=upsample_residual, + upsample_factor=upsample_factor, + ) + + latents_mean = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.ones((latent_channels,), requires_grad=False) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # This can be configured based on the amount of GPU memory available. + # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. + # Setting it to higher values results in higher memory usage. 
+ self.num_sample_frames_batch_size = 16 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + self.tile_sample_min_num_frames = 16 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + self.tile_sample_stride_num_frames = 8 + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_min_num_frames: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + tile_sample_stride_num_frames: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + enc = self.encoder(x) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
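+
+        Example (illustrative sketch; assumes a loaded `AutoencoderKLLTX2Video` instance `vae` and a video tensor
+        `video` of shape `(batch, channels, frames, height, width)` with matching device and dtype):
+
+        ```py
+        >>> posterior = vae.encode(video).latent_dist
+        >>> latents = posterior.sample()
+        ```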
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: + return self._temporal_tiled_decode(z, temb, return_dict=return_dict) + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, temb, return_dict=return_dict) + + dec = self.decoder(z, temb) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + if temb is not None: + decoded_slices = [ + self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1)) + ] + else: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z, temb).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
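+
+        Note that `encode` dispatches to this method automatically when tiling has been enabled via `enable_tiling`
+        and the input is spatially larger than the configured tile size.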
+ """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + time = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
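+        # Tiles are decoded at latent-stride intervals, cross-faded with their upper/left neighbours over the
+        # overlap region (via `blend_v` / `blend_h`), and then cropped to the sample-space stride before being
+        # concatenated back together.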
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile) + else: + tile = self.encoder(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, temb, return_dict=True).sample + else: + decoded = self.decoder(tile, temb) + if i > 0: + decoded = decoded[:, :, :-1, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + tile = tile[:, :, : 
self.tile_sample_stride_num_frames, :, :] + result_row.append(tile) + else: + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + temb: Optional[torch.Tensor] = None, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, temb) + if not return_dict: + return (dec.sample,) + return dec From baf23e2da3f0816d1ebe870ccd66249fa3e5ceaa Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 Dec 2025 11:14:45 +0100 Subject: [PATCH 14/86] Explicitly specify temporal and spatial VAE scale factors when converting --- scripts/convert_ltx2_to_diffusers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index dfec0262deb8..85fa169af3ac 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -241,6 +241,8 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "resnet_norm_eps": 1e-6, "encoder_causal": True, "decoder_causal": True, + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, }, } rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT @@ -274,6 +276,8 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "resnet_norm_eps": 1e-6, "encoder_causal": True, "decoder_causal": True, + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, }, } rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT From 5b950d6fefae4035d835e539c7b2676008ba43fc Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 Dec 2025 11:30:15 +0100 Subject: [PATCH 15/86] Add initial LTX 2.0 video VAE tests --- src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index 9f65c9980d18..755b92c10a02 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -633,7 +633,7 @@ def forward( # Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is # different, as is the layers_per_block (the 2.0 VAE is bigger) -class LTXVideoEncoder3d(nn.Module): +class LTX2VideoEncoder3d(nn.Module): r""" The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent representation. @@ -779,7 +779,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Like LTX 1.0 LTXVideoDecoder3d, but has only 3 symmetric up blocks which are causal and residual with upsample_factor 2 -class LTXVideoDecoder3d(nn.Module): +class LTX2VideoDecoder3d(nn.Module): r""" The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample. 
@@ -1011,7 +1011,7 @@ def __init__( ) -> None: super().__init__() - self.encoder = LTXVideoEncoder3d( + self.encoder = LTX2VideoEncoder3d( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, @@ -1024,7 +1024,7 @@ def __init__( resnet_norm_eps=resnet_norm_eps, is_causal=encoder_causal, ) - self.decoder = LTXVideoDecoder3d( + self.decoder = LTX2VideoDecoder3d( in_channels=latent_channels, out_channels=out_channels, block_out_channels=decoder_block_out_channels, From 491aae08d84d66a3db73f2fdeca96f109f28c4a7 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 17 Dec 2025 11:39:09 +0100 Subject: [PATCH 16/86] Add initial LTX 2.0 video VAE tests (part 2) --- .../test_models_autoencoder_ltx2_video.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/models/autoencoders/test_models_autoencoder_ltx2_video.py diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py new file mode 100644 index 000000000000..703ba54f89a1 --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py @@ -0,0 +1,102 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import AutoencoderKLLTX2Video + +from ...testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): + model_class = AutoencoderKLLTX2Video + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_ltx_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 8, + "block_out_channels": (8, 8, 8, 8), + "decoder_block_out_channels": (16, 32, 64), + "layers_per_block": (1, 1, 1, 1, 1), + "decoder_layers_per_block": (1, 1, 1, 1), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 1, + "patch_size_t": 1, + "encoder_causal": True, + "decoder_causal": True, + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + input_dict = {"sample": image} + return input_dict + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_ltx_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "LTX2VideoEncoder3d", + "LTX2VideoDecoder3d", + "LTX2VideoDownBlock3D", + "LTX2VideoMidBlock3d", + "LTX2VideoUpBlock3d", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.") + def test_forward_with_norm_groups(self): + pass From a748975a7c9a658b218694e10df6f9694e48078a Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 19 Dec 2025 07:02:38 +0100 Subject: [PATCH 17/86] Get diffusers implementation on par with official LTX 2.0 video VAE implementation --- scripts/convert_ltx2_to_diffusers.py | 8 +- .../autoencoders/autoencoder_kl_ltx2.py | 276 +++++++++++------- .../test_models_autoencoder_ltx2_video.py | 5 +- 3 files changed, 174 insertions(+), 115 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 85fa169af3ac..25a04e789347 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -240,7 +240,9 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "patch_size_t": 1, "resnet_norm_eps": 1e-6, "encoder_causal": True, - "decoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", "spatial_compression_ratio": 32, "temporal_compression_ratio": 8, }, @@ -275,7 +277,9 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "patch_size_t": 1, "resnet_norm_eps": 1e-6, "encoder_causal": True, - 
"decoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", "spatial_compression_ratio": 32, "temporal_compression_ratio": 8, }, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index 755b92c10a02..6e7b4d324fc4 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -29,8 +29,8 @@ from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution -# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoCausalConv3d -class LTXVideoCausalConv3d(nn.Module): +# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime +class LTX2VideoCausalConv3d(nn.Module): def __init__( self, in_channels: int, @@ -39,14 +39,12 @@ def __init__( stride: Union[int, Tuple[int, int, int]] = 1, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, - padding_mode: str = "zeros", - is_causal: bool = True, + spatial_padding_mode: str = "zeros", ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels - self.is_causal = is_causal self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) @@ -63,13 +61,13 @@ def __init__( dilation=dilation, groups=groups, padding=padding, - padding_mode=padding_mode, + padding_mode=spatial_padding_mode, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: time_kernel_size = self.kernel_size[0] - if self.is_causal: + if causal: pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) else: @@ -81,7 +79,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -# Like LTXVideoResnetBlock3d, but uses a normal Conv3d instead of a causal Conv3d for the conv_shortcut +# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding +# mode is configurable class LTX2VideoResnetBlock3d(nn.Module): r""" A 3D ResNet block used in the LTX 2.0 audiovisual model. 
@@ -111,9 +110,9 @@ def __init__( eps: float = 1e-6, elementwise_affine: bool = False, non_linearity: str = "swish", - is_causal: bool = True, inject_noise: bool = False, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -122,14 +121,20 @@ def __init__( self.nonlinearity = get_activation(non_linearity) self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) - self.conv1 = LTXVideoCausalConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + self.conv1 = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, ) self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) self.dropout = nn.Dropout(dropout) - self.conv2 = LTXVideoCausalConv3d( - in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + self.conv2 = LTX2VideoCausalConv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, ) self.norm3 = None @@ -140,9 +145,6 @@ def __init__( self.conv_shortcut = nn.Conv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1 ) - # self.conv_shortcut = LTXVideoCausalConv3d( - # in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal - # ) self.per_channel_scale1 = None self.per_channel_scale2 = None @@ -155,7 +157,11 @@ def __init__( self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) def forward( - self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None + self, + inputs: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: hidden_states = inputs @@ -168,7 +174,7 @@ def forward( hidden_states = hidden_states * (1 + scale_1) + shift_1 hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states, causal=causal) if self.per_channel_scale1 is not None: spatial_shape = hidden_states.shape[-2:] @@ -184,7 +190,7 @@ def forward( hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states, causal=causal) if self.per_channel_scale2 is not None: spatial_shape = hidden_states.shape[-2:] @@ -203,15 +209,14 @@ def forward( return hidden_states -# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoDownsampler3d +# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d class LTXVideoDownsampler3d(nn.Module): def __init__( self, in_channels: int, out_channels: int, stride: Union[int, Tuple[int, int, int]] = 1, - is_causal: bool = True, - padding_mode: str = "zeros", + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -220,16 +225,15 @@ def __init__( out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) - self.conv = LTXVideoCausalConv3d( + self.conv = LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, - is_causal=is_causal, - padding_mode=padding_mode, + spatial_padding_mode=spatial_padding_mode, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: bool = True) 
-> torch.Tensor: hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) residual = ( @@ -241,7 +245,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = residual.unflatten(1, (-1, self.group_size)) residual = residual.mean(dim=2) - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states, causal=causal) hidden_states = ( hidden_states.unflatten(4, (-1, self.stride[2])) .unflatten(3, (-1, self.stride[1])) @@ -253,16 +257,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoUpsampler3d +# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d class LTXVideoUpsampler3d(nn.Module): def __init__( self, in_channels: int, stride: Union[int, Tuple[int, int, int]] = 1, - is_causal: bool = True, residual: bool = False, upscale_factor: int = 1, - padding_mode: str = "zeros", + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -272,16 +275,15 @@ def __init__( out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor - self.conv = LTXVideoCausalConv3d( + self.conv = LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, - is_causal=is_causal, - padding_mode=padding_mode, + spatial_padding_mode=spatial_padding_mode, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape if self.residual: @@ -293,7 +295,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = residual.repeat(1, repeats, 1, 1, 1) residual = residual[:, :, self.stride[0] - 1 :] - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states, causal=causal) hidden_states = hidden_states.reshape( batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width ) @@ -342,8 +344,8 @@ def __init__( resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, - is_causal: bool = True, downsample_type: str = "conv", + spatial_padding_mode: str = "zeros", ): super().__init__() @@ -358,7 +360,7 @@ def __init__( dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -369,30 +371,39 @@ def __init__( if downsample_type == "conv": self.downsamplers.append( - LTXVideoCausalConv3d( + LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=(2, 2, 2), - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) ) elif downsample_type == "spatial": self.downsamplers.append( LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, ) ) elif downsample_type == "temporal": self.downsamplers.append( LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, ) ) elif downsample_type == "spatiotemporal": self.downsamplers.append( LTXVideoDownsampler3d( - in_channels=in_channels, 
out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, ) ) @@ -403,18 +414,19 @@ def forward( hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: r"""Forward method of the `LTXDownBlock3D` class.""" for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) else: - hidden_states = resnet(hidden_states, temb, generator) + hidden_states = resnet(hidden_states, temb, generator, causal=causal) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, causal=causal) return hidden_states @@ -449,9 +461,9 @@ def __init__( dropout: float = 0.0, resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", - is_causal: bool = True, inject_noise: bool = False, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -468,9 +480,9 @@ def __init__( dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -482,6 +494,7 @@ def forward( hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: r"""Forward method of the `LTXMidBlock3D` class.""" @@ -497,9 +510,9 @@ def forward( for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) else: - hidden_states = resnet(hidden_states, temb, generator) + hidden_states = resnet(hidden_states, temb, generator, causal=causal) return hidden_states @@ -540,11 +553,11 @@ def __init__( resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, - is_causal: bool = True, inject_noise: bool = False, timestep_conditioning: bool = False, upsample_residual: bool = False, upscale_factor: int = 1, + spatial_padding_mode: str = "zeros", ): super().__init__() @@ -562,9 +575,9 @@ def __init__( dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) self.upsamplers = None @@ -574,9 +587,9 @@ def __init__( LTXVideoUpsampler3d( out_channels * upscale_factor, stride=(2, 2, 2), - is_causal=is_causal, residual=upsample_residual, upscale_factor=upscale_factor, + spatial_padding_mode=spatial_padding_mode, ) ] ) @@ -590,9 +603,9 @@ def __init__( dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -604,9 +617,10 @@ def forward( hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: 
Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: if self.conv_in is not None: - hidden_states = self.conv_in(hidden_states, temb, generator) + hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal) if self.time_embedder is not None: temb = self.time_embedder( @@ -620,13 +634,13 @@ def forward( if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, causal=causal) for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) else: - hidden_states = resnet(hidden_states, temb, generator) + hidden_states = resnet(hidden_states, temb, generator, causal=causal) return hidden_states @@ -682,21 +696,23 @@ def __init__( patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, is_causal: bool = True, + spatial_padding_mode: str = "zeros", ): super().__init__() self.patch_size = patch_size self.patch_size_t = patch_size_t self.in_channels = in_channels * patch_size**2 + self.is_causal = is_causal output_channel = out_channels - self.conv_in = LTXVideoCausalConv3d( + self.conv_in = LTX2VideoCausalConv3d( in_channels=self.in_channels, out_channels=output_channel, kernel_size=3, stride=1, - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) # down blocks @@ -713,8 +729,8 @@ def __init__( num_layers=layers_per_block[i], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal, downsample_type=downsample_type[i], + spatial_padding_mode=spatial_padding_mode, ) else: raise ValueError(f"Unknown down block type: {down_block_types[i]}") @@ -726,19 +742,23 @@ def __init__( in_channels=output_channel, num_layers=layers_per_block[-1], resnet_eps=resnet_norm_eps, - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) # out self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() - self.conv_out = LTXVideoCausalConv3d( - in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=out_channels + 1, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, ) self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: r"""The forward method of the `LTXVideoEncoder3d` class.""" p = self.patch_size @@ -748,28 +768,29 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: post_patch_num_frames = num_frames // p_t post_patch_height = height // p post_patch_width = width // p + causal = causal or self.is_causal hidden_states = hidden_states.reshape( batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p ) # Thanks for driving me insane with the weird patching order :( hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) - hidden_states = self.conv_in(hidden_states) + hidden_states = self.conv_in(hidden_states, causal=causal) if torch.is_grad_enabled() and self.gradient_checkpointing: for down_block in self.down_blocks: - hidden_states = self._gradient_checkpointing_func(down_block, 
hidden_states) + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states, None, None, causal) - hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, None, None, causal) else: for down_block in self.down_blocks: - hidden_states = down_block(hidden_states) + hidden_states = down_block(hidden_states, causal=causal) - hidden_states = self.mid_block(hidden_states) + hidden_states = self.mid_block(hidden_states, causal=causal) hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) hidden_states = self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) last_channel = hidden_states[:, -1:] last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) @@ -817,17 +838,19 @@ def __init__( patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, - is_causal: bool = True, + is_causal: bool = False, inject_noise: Tuple[bool, ...] = (False, False, False), timestep_conditioning: bool = False, upsample_residual: Tuple[bool, ...] = (True, True, True), upsample_factor: Tuple[bool, ...] = (2, 2, 2), + spatial_padding_mode: str = "reflect", ) -> None: super().__init__() self.patch_size = patch_size self.patch_size_t = patch_size_t self.out_channels = out_channels * patch_size**2 + self.is_causal = is_causal block_out_channels = tuple(reversed(block_out_channels)) spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) @@ -837,17 +860,21 @@ def __init__( upsample_factor = tuple(reversed(upsample_factor)) output_channel = block_out_channels[0] - self.conv_in = LTXVideoCausalConv3d( - in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal + self.conv_in = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, ) self.mid_block = LTX2VideoMidBlock3d( in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, - is_causal=is_causal, inject_noise=inject_noise[0], timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) # up blocks @@ -863,11 +890,11 @@ def __init__( num_layers=layers_per_block[i + 1], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal, inject_noise=inject_noise[i + 1], timestep_conditioning=timestep_conditioning, upsample_residual=upsample_residual[i], upscale_factor=upsample_factor[i], + spatial_padding_mode=spatial_padding_mode, ) self.up_blocks.append(up_block) @@ -875,8 +902,12 @@ def __init__( # out self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() - self.conv_out = LTXVideoCausalConv3d( - in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, ) # timestep embedding @@ -890,22 +921,26 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: - hidden_states = self.conv_in(hidden_states) + def forward( + self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, causal: Optional[bool] = None, + 
) -> torch.Tensor: + causal = causal or self.is_causal + + hidden_states = self.conv_in(hidden_states, causal=causal) if self.timestep_scale_multiplier is not None: temb = temb * self.timestep_scale_multiplier if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb, None, causal) for up_block in self.up_blocks: - hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb, None, causal) else: - hidden_states = self.mid_block(hidden_states, temb) + hidden_states = self.mid_block(hidden_states, temb, causal=causal) for up_block in self.up_blocks: - hidden_states = up_block(hidden_states, temb) + hidden_states = up_block(hidden_states, temb, causal=causal) hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) @@ -923,7 +958,7 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) p = self.patch_size p_t = self.patch_size_t @@ -1006,6 +1041,8 @@ def __init__( scaling_factor: float = 1.0, encoder_causal: bool = True, decoder_causal: bool = True, + encoder_spatial_padding_mode: str = "zeros", + decoder_spatial_padding_mode: str = "reflect", spatial_compression_ratio: int = None, temporal_compression_ratio: int = None, ) -> None: @@ -1023,6 +1060,7 @@ def __init__( patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, is_causal=encoder_causal, + spatial_padding_mode=encoder_spatial_padding_mode, ) self.decoder = LTX2VideoDecoder3d( in_channels=latent_channels, @@ -1038,6 +1076,7 @@ def __init__( inject_noise=decoder_inject_noise, upsample_residual=upsample_residual, upsample_factor=upsample_factor, + spatial_padding_mode=decoder_spatial_padding_mode, ) latents_mean = torch.zeros((latent_channels,), requires_grad=False) @@ -1120,22 +1159,22 @@ def enable_tiling( self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames - def _encode(self, x: torch.Tensor) -> torch.Tensor: + def _encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = x.shape if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: - return self._temporal_tiled_encode(x) + return self._temporal_tiled_encode(x, causal=causal) if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): - return self.tiled_encode(x) + return self.tiled_encode(x, causal=causal) - enc = self.encoder(x) + enc = self.encoder(x, causal=causal) return enc @apply_forward_hook def encode( - self, x: torch.Tensor, return_dict: bool = True + self, x: torch.Tensor, causal: Optional[bool] = None, return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images into latents. @@ -1150,10 +1189,10 @@ def encode( [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
""" if self.use_slicing and x.shape[0] > 1: - encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + encoded_slices = [self._encode(x_slice, causal=causal) for x_slice in x.split(1)] h = torch.cat(encoded_slices) else: - h = self._encode(x) + h = self._encode(x, causal=causal) posterior = DiagonalGaussianDistribution(h) if not return_dict: @@ -1161,7 +1200,11 @@ def encode( return AutoencoderKLOutput(latent_dist=posterior) def _decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + self, + z: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + return_dict: bool = True, ) -> Union[DecoderOutput, torch.Tensor]: batch_size, num_channels, num_frames, height, width = z.shape tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio @@ -1169,12 +1212,12 @@ def _decode( tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: - return self._temporal_tiled_decode(z, temb, return_dict=return_dict) + return self._temporal_tiled_decode(z, temb, causal=causal, return_dict=return_dict) if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): - return self.tiled_decode(z, temb, return_dict=return_dict) + return self.tiled_decode(z, temb, causal=causal, return_dict=return_dict) - dec = self.decoder(z, temb) + dec = self.decoder(z, temb, causal=causal) if not return_dict: return (dec,) @@ -1183,7 +1226,11 @@ def _decode( @apply_forward_hook def decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + self, + z: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + return_dict: bool = True, ) -> Union[DecoderOutput, torch.Tensor]: """ Decode a batch of images. @@ -1201,13 +1248,13 @@ def decode( if self.use_slicing and z.shape[0] > 1: if temb is not None: decoded_slices = [ - self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1)) + self._decode(z_slice, t_slice, causal=causal).sample for z_slice, t_slice in (z.split(1), temb.split(1)) ] else: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)] decoded = torch.cat(decoded_slices) else: - decoded = self._decode(z, temb).sample + decoded = self._decode(z, temb, causal=causal).sample if not return_dict: return (decoded,) @@ -1238,7 +1285,7 @@ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. ) return b - def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + def tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. 
Args: @@ -1267,7 +1314,8 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: row = [] for j in range(0, width, self.tile_sample_stride_width): time = self.encoder( - x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width], + causal=causal, ) row.append(time) @@ -1290,7 +1338,7 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: return enc def tiled_decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True ) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images using a tiled decoder. @@ -1324,7 +1372,9 @@ def tiled_decode( for i in range(0, height, tile_latent_stride_height): row = [] for j in range(0, width, tile_latent_stride_width): - time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb) + time = self.decoder( + z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb, causal=causal + ) row.append(time) rows.append(row) @@ -1349,7 +1399,7 @@ def tiled_decode( return DecoderOutput(sample=dec) - def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + def _temporal_tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> AutoencoderKLOutput: batch_size, num_channels, num_frames, height, width = x.shape latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 @@ -1361,9 +1411,9 @@ def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: for i in range(0, num_frames, self.tile_sample_stride_num_frames): tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): - tile = self.tiled_encode(tile) + tile = self.tiled_encode(tile, causal=causal) else: - tile = self.encoder(tile) + tile = self.encoder(tile, causal=causal) if i > 0: tile = tile[:, :, 1:, :, :] row.append(tile) @@ -1380,7 +1430,7 @@ def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: return enc def _temporal_tiled_decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True ) -> Union[DecoderOutput, torch.Tensor]: batch_size, num_channels, num_frames, height, width = z.shape num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 @@ -1395,9 +1445,9 @@ def _temporal_tiled_decode( for i in range(0, num_frames, tile_latent_stride_num_frames): tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): - decoded = self.tiled_decode(tile, temb, return_dict=True).sample + decoded = self.tiled_decode(tile, temb, causal=causal, return_dict=True).sample else: - decoded = self.decoder(tile, temb) + decoded = self.decoder(tile, temb, causal=causal) if i > 0: decoded = decoded[:, :, :-1, :, :] row.append(decoded) @@ -1422,16 +1472,18 @@ def forward( sample: torch.Tensor, temb: Optional[torch.Tensor] = None, sample_posterior: bool = False, + encoder_causal: Optional[bool] = None, + decoder_causal: Optional[bool] = None, return_dict: bool = True, generator: Optional[torch.Generator] = None, ) -> Union[torch.Tensor, torch.Tensor]: x = sample - 
posterior = self.encode(x).latent_dist + posterior = self.encode(x, causal=encoder_causal).latent_dist if sample_posterior: z = posterior.sample(generator=generator) else: z = posterior.mode() - dec = self.decode(z, temb) + dec = self.decode(z, temb, causal=decoder_causal) if not return_dict: return (dec.sample,) return dec diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py index 703ba54f89a1..25984d621ac0 100644 --- a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py @@ -55,7 +55,10 @@ def get_autoencoder_kl_ltx_video_config(self): "patch_size": 1, "patch_size_t": 1, "encoder_causal": True, - "decoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + # Full model uses `reflect` but this does not have deterministic backward implementation, so use `zeros` + "decoder_spatial_padding_mode": "zeros", } @property From c6a11a553038e503f5f76f5bb667030a04504277 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 19 Dec 2025 12:17:10 +0100 Subject: [PATCH 18/86] Initial LTX 2.0 vocoder implementation --- scripts/convert_ltx2_to_diffusers.py | 65 ++++++++- src/diffusers/pipelines/ltx2/vocoder.py | 173 ++++++++++++++++++++++++ 2 files changed, 237 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/ltx2/vocoder.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 25a04e789347..f2e879c06562 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -10,6 +10,7 @@ from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel from diffusers.utils.import_utils import is_accelerate_available +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder CTX = init_empty_weights if is_accelerate_available() else nullcontext @@ -61,6 +62,13 @@ "per_channel_statistics.std-of-means": "latents_std", } +LTX_2_0_VOCODER_RENAME_DICT = { + "ups": "upsamplers", + "resblocks": "resnets", + "conv_pre": "conv_in", + "conv_post": "conv_out", +} + def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: state_dict[new_key] = state_dict.pop(old_key) @@ -99,6 +107,8 @@ def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) "per_channel_statistics.mean-of-stds": remove_keys_inplace, } +LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} + def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": @@ -315,6 +325,53 @@ def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> return vae +def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "in_channels": 128, + "hidden_channels": 1024, + "out_channels": 2, + "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_factors": [6, 5, 2, 2, 2], + "resnet_kernel_sizes": [3, 7, 11], + "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "leaky_relu_negative_slope": 0.1, + "output_sampling_rate": 24000, + } + } + rename_dict = LTX_2_0_VOCODER_RENAME_DICT + special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + 
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vocoder = LTX2Vocoder.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vocoder.load_state_dict(original_state_dict, strict=True, assign=True) + return vocoder + + def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) @@ -468,7 +525,13 @@ def main(args): transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer")) if args.vocoder or args.full_pipeline: - pass + if args.vocoder_filename is not None: + original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename) + elif combined_ckpt is not None: + original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix) + vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version) + if not args.full_pipeline: + vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder")) if args.full_pipeline: pass diff --git a/src/diffusers/pipelines/ltx2/vocoder.py b/src/diffusers/pipelines/ltx2/vocoder.py new file mode 100644 index 000000000000..c3b3c1f36796 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/vocoder.py @@ -0,0 +1,173 @@ +import math +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class ResBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int = 3, + stride: int = 1, + dilations: Tuple[int, ...] = (1, 3, 5), + leaky_relu_negative_slope: float = 0.1, + padding_mode: str = "same", + ): + super().__init__() + self.dilations = dilations + self.negative_slope = leaky_relu_negative_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding_mode + ) + for dilation in dilations + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=stride, + dilation=1, + padding=padding_mode + ) + for _ in range(len(dilations)) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, negative_slope=self.negative_slope) + xt = conv1(xt) + xt = F.leaky_relu(xt, negative_slope=self.negative_slope) + xt = conv2(xt) + x = x + xt + return x + + +class LTX2Vocoder(ModelMixin, ConfigMixin): + r""" + LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms. 
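+
+    Mel-spectrogram frames are first projected by `conv_in`, then upsampled through a stack of
+    `ConvTranspose1d` layers; each upsampling stage is followed by a bank of dilated residual blocks whose
+    outputs are averaged. A final convolution and `tanh` produce the waveform, so the time dimension grows by
+    a factor of `math.prod(upsample_factors)` overall.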
+ """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + hidden_channels: int = 1024, + out_channels: int = 2, + upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4], + upsample_factors: List[int] = [6, 5, 2, 2, 2], + resnet_kernel_sizes: List[int] = [3, 7, 11], + resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + leaky_relu_negative_slope: float = 0.1, + output_sampling_rate: int = 24000, + ): + super().__init__() + self.num_upsample_layers = len(upsample_kernel_sizes) + self.resnets_per_upsample = len(resnet_kernel_sizes) + self.out_channels = out_channels + self.total_upsample_factor = math.prod(upsample_factors) + self.negative_slope = leaky_relu_negative_slope + + if self.num_upsample_layers != len(upsample_factors): + raise ValueError( + f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length" + f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively." + ) + + if self.resnets_per_upsample != len(resnet_dilations): + raise ValueError( + f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length" + f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." + ) + + self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3) + + self.upsamplers = nn.ModuleList() + self.resnets = nn.ModuleList() + input_channels = hidden_channels + for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + output_channels = input_channels // 2 + self.upsamplers.append( + nn.ConvTranspose1d( + input_channels, # hidden_channels // (2 ** i) + output_channels, # hidden_channels // (2 ** (i + 1)) + kernel_size, + stride=stride, + padding=(kernel_size - stride) // 2, + ) + ) + + for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): + self.resnets.append( + ResBlock( + output_channels, + kernel_size, + dilations=dilations, + leaky_relu_negative_slope=leaky_relu_negative_slope, + ) + ) + input_channels = output_channels + + self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) + + def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: + r""" + Forward pass of the vocoder. + + Args: + hidden_states (`torch.Tensor`): + Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last` + is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is + `True`. + time_last (`bool`, *optional*, defaults to `False`): + Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension. 
+ + Returns: + `torch.Tensor`: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + + # Ensure that the time/frame dimension is last + if not time_last: + hidden_states = hidden_states.transpose(2, 3) + # Combine channels and frequency (mel bins) dimensions + hidden_states = hidden_states.flatten(1, 2) + + hidden_states = self.conv_in(hidden_states) + + for i in range(self.num_upsample_layers): + hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) + hidden_states = self.upsamplers[i](hidden_states) + + # Run all resnets in parallel on hidden_states + start = i * self.resnets_per_upsample + end = (i + 1) * self.resnets_per_upsample + resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0) + + hidden_states = torch.mean(resnet_outputs, dim=0) + + # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of + # 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended + hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01) + hidden_states = self.conv_out(hidden_states) + hidden_states = torch.tanh(hidden_states) + + return hidden_states From 6c56954fa876cd0aef5054d1eb0dc3ad684ebaa3 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 20 Dec 2025 02:40:38 +0100 Subject: [PATCH 19/86] Use RMSNorm implementation closer to original for LTX 2.0 video VAE --- .../autoencoders/autoencoder_kl_ltx2.py | 49 +++++++++++++++---- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index 6e7b4d324fc4..df59e2d74868 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -29,6 +29,38 @@ from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution +class PerChannelRMSNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + + For each element along the chosen dimension, this layer normalizes the tensor + by the root-mean-square of its values across that dimension: + + y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) + """ + + def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None: + """ + Args: + dim: Dimension along which to compute the RMS (typically channels). + eps: Small constant added for numerical stability. + """ + super().__init__() + self.channel_dim = channel_dim + self.eps = eps + + def forward(self, x: torch.Tensor, channel_dim: Optional[int] = None) -> torch.Tensor: + """ + Apply RMS normalization along the configured dimension. + """ + channel_dim = channel_dim or self.channel_dim + # Compute mean of squared values along `dim`, keep dimensions for broadcasting. + mean_sq = torch.mean(x**2, dim=self.channel_dim, keepdim=True) + # Normalize by the root-mean-square (RMS). 
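+        # Unlike the RMSNorm layers this norm replaces, no learnable scale or shift is applied here.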
+ rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + # Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime class LTX2VideoCausalConv3d(nn.Module): def __init__( @@ -120,7 +152,7 @@ def __init__( self.nonlinearity = get_activation(non_linearity) - self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.norm1 = PerChannelRMSNorm() self.conv1 = LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, @@ -128,7 +160,7 @@ def __init__( spatial_padding_mode=spatial_padding_mode, ) - self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.norm2 = PerChannelRMSNorm() self.dropout = nn.Dropout(dropout) self.conv2 = LTX2VideoCausalConv3d( in_channels=out_channels, @@ -165,8 +197,7 @@ def forward( ) -> torch.Tensor: hidden_states = inputs - # Normalize over the channels dimension (dim 1), which is not the last dim - hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.norm1(hidden_states) if self.scale_shift_table is not None: temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] @@ -183,7 +214,7 @@ def forward( )[None] hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] - hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.norm2(hidden_states) if self.scale_shift_table is not None: hidden_states = hidden_states * (1 + scale_2) + shift_2 @@ -746,7 +777,7 @@ def __init__( ) # out - self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.norm_out = PerChannelRMSNorm() self.conv_act = nn.SiLU() self.conv_out = LTX2VideoCausalConv3d( in_channels=output_channel, @@ -788,7 +819,7 @@ def forward(self, hidden_states: torch.Tensor, causal: Optional[bool] = None) -> hidden_states = self.mid_block(hidden_states, causal=causal) - hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.norm_out(hidden_states) hidden_states = self.conv_act(hidden_states) hidden_states = self.conv_out(hidden_states, causal=causal) @@ -900,7 +931,7 @@ def __init__( self.up_blocks.append(up_block) # out - self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.norm_out = PerChannelRMSNorm() self.conv_act = nn.SiLU() self.conv_out = LTX2VideoCausalConv3d( in_channels=output_channel, @@ -942,7 +973,7 @@ def forward( for up_block in self.up_blocks: hidden_states = up_block(hidden_states, temb, causal=causal) - hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.norm_out(hidden_states) if self.time_embedder is not None: temb = self.time_embedder( From b34ddb1736377a9b2e01dea5408b99a8cc147f28 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 12:23:31 +0530 Subject: [PATCH 20/86] start audio decoder. --- .../autoencoders/autoencoder_kl_ltx2_audio.py | 655 ++++++++++++++++++ 1 file changed, 655 insertions(+) create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py new file mode 100644 index 000000000000..98d8a53e2359 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -0,0 +1,655 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +from typing import Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput + + +LATENT_DOWNSAMPLE_FACTOR = 4 +SUPPORTED_CAUSAL_AXES = {"none", "width", "height", "width-compatibility"} + + +AudioLatentShape = namedtuple( + "AudioLatentShape", + [ + "batch", + "channels", + "frames", + "mel_bins", + ], +) + + +def _resolve_causality_axis(causality_axis: Optional[str] = None) -> Optional[str]: + normalized = "none" if causality_axis is None else str(causality_axis).lower() + if normalized not in SUPPORTED_CAUSAL_AXES: + raise NotImplementedError( + f"Unsupported causality_axis '{causality_axis}'. Supported: {sorted(SUPPORTED_CAUSAL_AXES)}" + ) + return None if normalized == "none" else normalized + + +def make_conv2d( + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: int = 1, + padding: Optional[Tuple[int, int, int, int]] = None, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causality_axis: Optional[str] = None, +) -> nn.Module: + if causality_axis is not None: + return LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis + ) + if padding is None: + padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size) + + return nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + +class LTX2AudioCausalConv2d(nn.Module): + """ + A causal 2D convolution that pads asymmetrically along the causal axis. 
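+
+    For `causality_axis="height"` the full height padding is applied before the first row, and for
+    `causality_axis="width"` (or `"width-compatibility"`) the full width padding is applied before the first
+    column, so outputs depend only on current and previous positions along the causal axis. The remaining axis
+    is padded symmetrically, as in a standard "same" convolution.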
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: int = 1, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: str = "height", + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + kernel_size = nn.modules.utils._pair(kernel_size) + dilation = nn.modules.utils._pair(dilation) + + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + if self.causality_axis == "none": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis in {"width", "width-compatibility"}: + padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis == "height": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + else: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + self.padding = padding + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding) + return self.conv(x) + + +class LTX2AudioPixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +def build_normalization_layer(in_channels: int, *, num_groups: int = 32, normtype: str = "group") -> nn.Module: + if normtype == "group": + return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if normtype == "pixel": + return LTX2AudioPixelNorm(dim=1, eps=1e-6) + raise ValueError(f"Invalid normalization type: {normtype}") + + +class LTX2AudioAttnBlock(nn.Module): + def __init__( + self, + in_channels: int, + norm_type: str = "group", + ) -> None: + super().__init__() + self.in_channels = in_channels + + self.norm = build_normalization_layer(in_channels, normtype=norm_type) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = self.norm(x) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + batch, channels, height, width = q.shape + q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous() + k = k.reshape(batch, channels, height * width).contiguous() + attn = torch.bmm(q, k) * (int(channels) ** (-0.5)) + attn = torch.nn.functional.softmax(attn, dim=2) + + v = v.reshape(batch, channels, height * width) + attn = attn.permute(0, 2, 1).contiguous() + h_ = torch.bmm(v, attn).reshape(batch, channels, height, width) + + h_ = self.proj_out(h_) + return x + h_ + + +class LTX2AudioResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: str = "group", + causality_axis: str = "height", + ) -> None: + super().__init__() + self.causality_axis = causality_axis + + if 
self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group": + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) + self.non_linearity = nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) + self.dropout = nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward( + self, + x: torch.Tensor, + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + h = self.norm1(x) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) + + return x + h + + +class LTX2AudioUpsample(nn.Module): + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: Optional[str] = "height", + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + if self.causality_axis is None or self.causality_axis == "none": + pass + elif self.causality_axis == "height": + x = x[:, :, 1:, :] + elif self.causality_axis == "width": + x = x[:, :, :, 1:] + elif self.causality_axis == "width-compatibility": + pass + else: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class LTX2AudioPerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over + the entire dataset and stored in model's checkpoint under AudioVAE state_dict + """ + + def __init__(self, latent_channels: int = 128) -> None: + super().__init__() + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + + def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) + + +class LTX2AudioAudioPatchifier: + """ + Patchifier for spectrogram/audio latents. 
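+
+    `patchify` flattens a `(batch, channels, frames, mel_bins)` latent into a `(batch, frames, channels * mel_bins)`
+    sequence, and `unpatchify` restores the original layout from the shape stored in an `AudioLatentShape`.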
+ """ + + def __init__( + self, + patch_size: int, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + is_causal: bool = True, + ): + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self._patch_size = (1, patch_size, patch_size) + + def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor: + batch, channels, time, freq = audio_latents.shape + return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq) + + def unpatchify( + self, + audio_latents: torch.Tensor, + output_shape: AudioLatentShape, + ) -> torch.Tensor: + batch, time, _ = audio_latents.shape + channels = output_shape.channels + freq = output_shape.mel_bins + return audio_latents.view(batch, time, channels, freq).permute(0, 2, 1, 3) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + +class LTX2AudioDecoder(nn.Module): + """ + Symmetric decoder that reconstructs audio spectrograms from latent features. + + The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal + convolutions. + """ + + def __init__( + self, + base_channels: int, + output_channels: int, + num_res_blocks: int, + attn_resolutions: Set[int], + in_channels: int, + resolution: int, + latent_channels: int, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + norm_type: str = "group", + causality_axis: Optional[str] = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = True, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = None, + ) -> None: + super().__init__() + + resolved_causality_axis = _resolve_causality_axis(causality_axis) + + self.per_channel_statistics = LTX2AudioPerChannelStatistics(latent_channels=base_channels) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = LTX2AudioAudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = resolved_causality_axis + + base_block_channels = base_channels * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + self.conv_in = make_conv2d( + latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + self.non_linearity = nn.SiLU() + self.mid = self._build_mid_layers(base_block_channels, dropout, mid_block_add_attention) + self.up, final_block_channels = self._build_up_path( + initial_block_channels=base_block_channels, + dropout=dropout, + resamp_with_conv=True, + ) + + self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) + self.conv_out = make_conv2d( + final_block_channels, output_channels, kernel_size=3, stride=1, 
causality_axis=self.causality_axis + ) + + def _adjust_output_shape( + self, + decoded_output: torch.Tensor, + target_shape: AudioLatentShape, + ) -> torch.Tensor: + _, _, current_time, current_freq = decoded_output.shape + target_channels = target_shape.channels + target_time = target_shape.frames + target_freq = target_shape.mel_bins + + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + if time_padding_needed > 0 or freq_padding_needed > 0: + padding = ( + 0, + max(freq_padding_needed, 0), + 0, + max(time_padding_needed, 0), + ) + decoded_output = F.pad(decoded_output, padding) + + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def forward( + self, + sample: torch.Tensor, + ) -> torch.Tensor: + latent_shape = AudioLatentShape( + batch=sample.shape[0], + channels=sample.shape[1], + frames=sample.shape[2], + mel_bins=sample.shape[3], + ) + + sample_patched = self.patchifier.patchify(sample) + sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) + sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) + + target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR + + if self.causality_axis is not None: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_shape = AudioLatentShape( + batch=latent_shape.batch, + channels=self.out_ch, + frames=target_frames, + mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, + ) + + hidden_features = self.conv_in(sample) + hidden_features = self._run_mid_layers(hidden_features) + hidden_features = self._run_upsampling_path(hidden_features) + decoded_output = self._finalize_output(hidden_features) + + decoded_output = self._adjust_output_shape(decoded_output, target_shape) + + return decoded_output + + def _build_mid_layers(self, channels: int, dropout: float, add_attention: bool) -> nn.Module: + mid = nn.Module() + mid.block_1 = LTX2AudioResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + mid.attn_1 = LTX2AudioAttnBlock(channels, norm_type=self.norm_type) if add_attention else nn.Identity() + mid.block_2 = LTX2AudioResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + return mid + + def _build_up_path( + self, + initial_block_channels: int, + dropout: float, + resamp_with_conv: bool, + ) -> tuple[nn.ModuleList, int]: + up_modules = nn.ModuleList() + block_in = initial_block_channels + curr_res = self.resolution // (2 ** (self.num_resolutions - 1)) + + for level in reversed(range(self.num_resolutions)): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks + 1): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, 
norm_type=self.norm_type)) + + if level != 0: + stage.upsample = LTX2AudioUpsample(block_in, resamp_with_conv, causality_axis=self.causality_axis) + curr_res *= 2 + + up_modules.insert(0, stage) + + return up_modules, block_in + + def _run_mid_layers(self, features: torch.Tensor) -> torch.Tensor: + features = self.mid.block_1(features, temb=None) + features = self.mid.attn_1(features) + return self.mid.block_2(features, temb=None) + + def _run_upsampling_path(self, features: torch.Tensor) -> torch.Tensor: + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + features = block(features, temb=None) + if stage.attn: + features = stage.attn[block_idx](features) + + if level != 0 and hasattr(stage, "upsample"): + features = stage.upsample(features) + + return features + + def _finalize_output(self, features: torch.Tensor) -> torch.Tensor: + if self.give_pre_end: + return features + + hidden = self.norm_out(features) + hidden = self.non_linearity(hidden) + decoded = self.conv_out(hidden) + return torch.tanh(decoded) if self.tanh_out else decoded + + +class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + LTX2 audio VAE. Currently, only implements the decoder. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + base_channels: int = 128, + output_channels: int = 2, + ch_mult: Tuple[int] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: Tuple[int] = (8, 16, 32), + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + norm_type: str = "pixel", + causality_axis: Optional[str] = "height", + dropout: float = 0.0, + mid_block_add_attention: bool = True, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = None, + ) -> None: + super().__init__() + + resolved_causality_axis = _resolve_causality_axis(causality_axis) + attn_resolution_set = set(attn_resolutions) + + self.decoder = LTX2AudioDecoder( + base_channels=base_channels, + output_channels=output_channels, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolution_set, + in_channels=in_channels, + resolution=resolution, + latent_channels=latent_channels, + norm_type=norm_type, + causality_axis=resolved_causality_axis, + dropout=dropout, + mid_block_add_attention=mid_block_add_attention, + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + ) + + self.use_slicing = False + + @apply_forward_hook + def encode( + self, + x: torch.Tensor, + return_dict: bool = True, + ): + raise NotImplementedError("AutoencoderKLLTX2Audio does not implement encoding.") + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + return self.decoder(z) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "This model doesn't have an encoder yet so we don't implement its `forward()`. Please use `decode()`." 
+ ) From f4c2435d61f03e6e97bcbafec1ece6b5bcf50357 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 12:25:36 +0530 Subject: [PATCH 21/86] init registration. --- src/diffusers/models/autoencoders/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 56df27f93cd7..032bbe412352 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -10,6 +10,7 @@ from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo +from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage From e54cd6bb1d40f806a7b227500da7514a091e07d2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 13:03:40 +0530 Subject: [PATCH 22/86] up --- scripts/test_ltx2_audio_conversion.py | 106 ++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 scripts/test_ltx2_audio_conversion.py diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py new file mode 100644 index 000000000000..251d0b64e969 --- /dev/null +++ b/scripts/test_ltx2_audio_conversion.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +""" +Quick check that an LTX2 audio decoder checkpoint converts cleanly to the diffusers +`AutoencoderKLLTX2Audio` layout and produces matching outputs on dummy data. +""" + +import argparse +import sys +from pathlib import Path + +import torch + + +def convert_state_dict(state_dict: dict) -> dict: + converted = {} + for key, value in state_dict.items(): + if not isinstance(value, torch.Tensor): + continue + new_key = key + if new_key.startswith("decoder."): + new_key = new_key[len("decoder.") :] + converted[f"decoder.{new_key}"] = value + return converted + + +def load_original_decoder(original_repo: Path, device: torch.device, dtype: torch.dtype, checkpoint_path: Path | None): + ltx_core_src = original_repo / "ltx-core" / "src" + if not ltx_core_src.exists(): + raise FileNotFoundError(f"ltx-core sources not found under {ltx_core_src}") + sys.path.insert(0, str(ltx_core_src)) + + from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator + + decoder = VAEDecoderConfigurator.from_config({}).to(device=device, dtype=dtype) + + if checkpoint_path is not None: + raw_state = torch.load(checkpoint_path, map_location=device) + state_dict = raw_state.get("state_dict", raw_state) + decoder_state: dict[str, torch.Tensor] = {} + for key, value in state_dict.items(): + if not isinstance(value, torch.Tensor): + continue + trimmed = key + if trimmed.startswith("audio_vae.decoder."): + trimmed = trimmed[len("audio_vae.decoder.") :] + elif trimmed.startswith("decoder."): + trimmed = trimmed[len("decoder.") :] + decoder_state[trimmed] = value + decoder.load_state_dict(decoder_state, strict=False) + + decoder.eval() + return decoder + + +def build_diffusers_decoder(device: torch.device, dtype: torch.dtype): + from diffusers.models.autoencoders.autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio + + model = AutoencoderKLLTX2Audio().to(device=device, dtype=dtype) + model.eval() + return model + + +def main() -> None: + parser = argparse.ArgumentParser(description="Validate LTX2 audio 
decoder conversion.") + parser.add_argument( + "--original-repo", + type=Path, + default=Path("/Users/sayakpaul/Downloads/ltx-2"), + help="Path to the original ltx-2 repository (needed to import ltx-core).", + ) + parser.add_argument( + "--checkpoint", + type=Path, + default=None, + help="Optional path to an original checkpoint containing decoder weights.", + ) + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16", "float16"]) + parser.add_argument("--batch", type=int, default=2) + args = parser.parse_args() + + device = torch.device(args.device) + dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} + dtype = dtype_map[args.dtype] + + original_decoder = load_original_decoder(args.original_repo, device, dtype, args.checkpoint) + diffusers_model = build_diffusers_decoder(device, dtype) + + converted_state = convert_state_dict(original_decoder.state_dict()) + diffusers_model.load_state_dict(converted_state, strict=False) + + levels = len(diffusers_model.decoder.channel_multipliers) + latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1)) + dummy = torch.randn(args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device, dtype=dtype) + + with torch.no_grad(): + original_out = original_decoder(dummy) + diffusers_out = diffusers_model.decode(dummy).sample + + torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4) + max_diff = (diffusers_out - original_out).abs().max().item() + print(f"Conversion successful. Max diff: {max_diff:.6f}") + + +if __name__ == "__main__": + main() From 907896d533ae7089c30cd98790975c4ad5dd6b48 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 13:41:41 +0530 Subject: [PATCH 23/86] simplify and clean up --- scripts/test_ltx2_audio_conversion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py index 251d0b64e969..649b6d06d625 100644 --- a/scripts/test_ltx2_audio_conversion.py +++ b/scripts/test_ltx2_audio_conversion.py @@ -91,7 +91,9 @@ def main() -> None: levels = len(diffusers_model.decoder.channel_multipliers) latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1)) - dummy = torch.randn(args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device, dtype=dtype) + dummy = torch.randn( + args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device, dtype=dtype + ) with torch.no_grad(): original_out = original_decoder(dummy) From 4904fd6fa520894d586ec740bc2a10177e306883 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 13:46:58 +0530 Subject: [PATCH 24/86] up --- scripts/test_ltx2_audio_conversion.py | 88 ++++++++++++--------------- 1 file changed, 39 insertions(+), 49 deletions(-) diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py index 649b6d06d625..f9554782c9f3 100644 --- a/scripts/test_ltx2_audio_conversion.py +++ b/scripts/test_ltx2_audio_conversion.py @@ -1,14 +1,19 @@ -#!/usr/bin/env python -""" -Quick check that an LTX2 audio decoder checkpoint converts cleanly to the diffusers -`AutoencoderKLLTX2Audio` layout and produces matching outputs on dummy data. 
-""" - import argparse -import sys from pathlib import Path +import safetensors.torch import torch +from huggingface_hub import hf_hub_download + + +def download_checkpoint( + repo_id="diffusers-internal-dev/new-ltx-model", + filename="ltx-av-step-1932500-interleaved-new-vae.safetensors", + device="cuda", +): + ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) + ckpt = safetensors.torch.load_file(ckpt_path, device=device)["audio_vae"] + return ckpt def convert_state_dict(state_dict: dict) -> dict: @@ -23,71 +28,57 @@ def convert_state_dict(state_dict: dict) -> dict: return converted -def load_original_decoder(original_repo: Path, device: torch.device, dtype: torch.dtype, checkpoint_path: Path | None): - ltx_core_src = original_repo / "ltx-core" / "src" - if not ltx_core_src.exists(): - raise FileNotFoundError(f"ltx-core sources not found under {ltx_core_src}") - sys.path.insert(0, str(ltx_core_src)) - +def load_original_decoder(device: torch.device, dtype: torch.dtype): from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator - decoder = VAEDecoderConfigurator.from_config({}).to(device=device, dtype=dtype) - - if checkpoint_path is not None: - raw_state = torch.load(checkpoint_path, map_location=device) - state_dict = raw_state.get("state_dict", raw_state) - decoder_state: dict[str, torch.Tensor] = {} - for key, value in state_dict.items(): - if not isinstance(value, torch.Tensor): - continue - trimmed = key - if trimmed.startswith("audio_vae.decoder."): - trimmed = trimmed[len("audio_vae.decoder.") :] - elif trimmed.startswith("decoder."): - trimmed = trimmed[len("decoder.") :] - decoder_state[trimmed] = value - decoder.load_state_dict(decoder_state, strict=False) + with torch.device("meta"): + decoder = VAEDecoderConfigurator.from_config({}).to(device=device, dtype=dtype) + original_state_dict = download_checkpoint(device) + + decoder_state_dict = {} + for key, value in original_state_dict.items(): + if not isinstance(value, torch.Tensor): + continue + trimmed = key + if trimmed.startswith("audio_vae.decoder."): + trimmed = trimmed[len("audio_vae.decoder.") :] + elif trimmed.startswith("decoder."): + trimmed = trimmed[len("decoder.") :] + decoder_state_dict[trimmed] = value + decoder.load_state_dict(decoder_state_dict, strict=True, assign=True) decoder.eval() return decoder def build_diffusers_decoder(device: torch.device, dtype: torch.dtype): - from diffusers.models.autoencoders.autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio + from diffusers.models.autoencoders import AutoencoderKLLTX2Audio + + with torch.device("meta"): + model = AutoencoderKLLTX2Audio().to(device=device, dtype=dtype) - model = AutoencoderKLLTX2Audio().to(device=device, dtype=dtype) model.eval() return model +@torch.no_grad() def main() -> None: parser = argparse.ArgumentParser(description="Validate LTX2 audio decoder conversion.") - parser.add_argument( - "--original-repo", - type=Path, - default=Path("/Users/sayakpaul/Downloads/ltx-2"), - help="Path to the original ltx-2 repository (needed to import ltx-core).", - ) - parser.add_argument( - "--checkpoint", - type=Path, - default=None, - help="Optional path to an original checkpoint containing decoder weights.", - ) parser.add_argument("--device", type=str, default="cpu") - parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16", "float16"]) + parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"]) parser.add_argument("--batch", type=int, 
default=2) + parser.add_argument("--output-path", type=Path, required=True) args = parser.parse_args() device = torch.device(args.device) dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} dtype = dtype_map[args.dtype] - original_decoder = load_original_decoder(args.original_repo, device, dtype, args.checkpoint) + original_decoder = load_original_decoder(device, dtype) diffusers_model = build_diffusers_decoder(device, dtype) converted_state = convert_state_dict(original_decoder.state_dict()) - diffusers_model.load_state_dict(converted_state, strict=False) + diffusers_model.load_state_dict(converted_state, assign=True, strict=True) levels = len(diffusers_model.decoder.channel_multipliers) latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1)) @@ -95,9 +86,8 @@ def main() -> None: args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device, dtype=dtype ) - with torch.no_grad(): - original_out = original_decoder(dummy) - diffusers_out = diffusers_model.decode(dummy).sample + original_out = original_decoder(dummy) + diffusers_out = diffusers_model.decode(dummy).sample torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4) max_diff = (diffusers_out - original_out).abs().max().item() From 0028955c37e4e1c8c8973a7217ae15f6790bc4ef Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 22 Dec 2025 10:06:01 +0100 Subject: [PATCH 25/86] Initial LTX 2.0 text encoder implementation --- scripts/convert_ltx2_to_diffusers.py | 118 +++- src/diffusers/pipelines/ltx2/text_encoder.py | 625 +++++++++++++++++++ 2 files changed, 742 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/ltx2/text_encoder.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index f2e879c06562..8c59b1ea7785 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -7,9 +7,11 @@ import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download +from transformers import AutoModel, AutoProcessor from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel from diffusers.utils.import_utils import is_accelerate_available +from diffusers.pipelines.ltx2.text_encoder import LTX2AudioVisualTextEncoder from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder @@ -69,6 +71,15 @@ "conv_post": "conv_out", } +LTX_2_0_TEXT_ENCODER_RENAME_DICT = { + "video_embeddings_connector": "video_connector", + "audio_embeddings_connector": "audio_connector", + "transformer_1d_blocks": "transformer_blocks", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: state_dict[new_key] = state_dict.pop(old_key) @@ -109,6 +120,8 @@ def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} +LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP = {} + def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": @@ -372,6 +385,82 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D return vocoder +def get_ltx2_text_encoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "text_encoder_hidden_dim": 3840, + 
"text_proj_in_factor": 49, + "video_connector_num_attention_heads": 30, + "video_connector_attention_head_dim": 128, + "video_connector_num_layers": 2, + "video_connector_num_learnable_registers": 128, + "audio_connector_num_attention_heads": 30, + "audio_connector_attention_head_dim": 128, + "audio_connector_num_layers": 2, + "audio_connector_num_learnable_registers": 128, + "rope_base_seq_len": 4096, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_temporal_positioning": False, + }, + } + rename_dict = LTX_2_0_TEXT_ENCODER_RENAME_DICT + special_keys_remap = LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def get_text_encoder_keys_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str = "model.diffusion_model."): + model_state_dict = {} + + model_state_dict["text_proj_in.weight"] = combined_ckpt["text_embedding_projection.aggregate_embed.weight"] + + text_encoder_submodules = ["video_embeddings_connector", "audio_embeddings_connector"] + for param_name, param in combined_ckpt.items(): + if param_name.startswith(prefix): + new_param_name = param_name.replace(prefix, "") + for submodule_name in text_encoder_submodules: + if new_param_name.startswith(submodule_name): + model_state_dict[new_param_name] = param + break + + return model_state_dict + + +def convert_ltx2_text_encoder(original_state_dict: Dict[str, Any], version: str, text_model_id: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_text_encoder_config(version) + diffusers_config = config["diffusers_config"] + diffusers_config["text_model_id"] = text_model_id + diffusers_config["config_only"] = True + + with init_empty_weights(): + text_encoder = LTX2AudioVisualTextEncoder.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + base_text_model = AutoModel.from_pretrained(text_model_id) + base_text_model_state_dict= base_text_model.state_dict() + base_text_model_state_dict = {"base_text_encoder." 
+ k: v for k, v in base_text_model_state_dict.items()} + combined_state_dict = {**original_state_dict, **base_text_model_state_dict} + + text_encoder.load_state_dict(combined_state_dict, strict=True, assign=True) + return text_encoder + + def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) @@ -458,11 +547,24 @@ def get_args(): parser.add_argument( "--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set" ) + parser.add_argument( + "--text_encoder_model_id", + default="google/gemma-3-12b-it-qat-q4_0-unquantized", + type=str, + help="HF Hub id for the LTX 2.0 base text encoder model", + ) + parser.add_argument( + "--tokenizer_id", + default="google/gemma-3-12b-it-qat-q4_0-unquantized", + type=str, + help="HF Hub id for the LTX 2.0 text tokenizer", + ) parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model") + parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder") parser.add_argument( "--full_pipeline", action="store_true", @@ -473,6 +575,7 @@ def get_args(): parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") @@ -497,9 +600,12 @@ def main(args): audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype] dit_dtype = DTYPE_MAPPING[args.dit_dtype] vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype] + text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype] combined_ckpt = None - load_combined_models = any([args.vae, args.audio_vae, args.dit, args.vocoder, args.full_pipeline]) + load_combined_models = any( + [args.vae, args.audio_vae, args.dit, args.vocoder, args.text_encoder, args.full_pipeline] + ) if args.combined_filename is not None and load_combined_models: combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename) @@ -533,6 +639,16 @@ def main(args): if not args.full_pipeline: vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder")) + if args.text_encoder or args.full_pipeline: + text_encoder_ckpt = get_text_encoder_keys_from_combined_ckpt(combined_ckpt) + text_encoder = convert_ltx2_text_encoder(text_encoder_ckpt, args.version, args.text_encoder_model_id) + if not args.full_pipeline: + text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder")) + + tokenizer = AutoProcessor.from_pretrained(args.tokenizer_id) + if not args.full_pipeline: + tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer")) + if args.full_pipeline: pass diff --git a/src/diffusers/pipelines/ltx2/text_encoder.py b/src/diffusers/pipelines/ltx2/text_encoder.py new file mode 100644 index 
000000000000..f15fa62224d2 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/text_encoder.py @@ -0,0 +1,625 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoConfig, AutoModel, Gemma3ForConditionalGeneration + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ...models.attention_dispatch import dispatch_attention_fn +from ...models.embeddings import get_1d_rotary_pos_embed +from ...models.modeling_utils import ModelMixin +from ...utils import is_torch_version, logging +from ..pipeline_loading_utils import _fetch_class_library_tuple + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def apply_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + +# Copied from diffusers.models.transformers.transformer_ltx2.LTX2AudioVideoAttnProcessor +class LTX2AudioVideoAttnProcessor: + r""" + Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model. + Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can + support audio-to-video (a2v) and video-to-audio (v2a) cross attention. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if is_torch_version("<", "2.0"): + raise ValueError( + "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." 
+ ) + + def __call__( + self, + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if query_rotary_emb is not None: + query = apply_rotary_emb(query, query_rotary_emb) + key = apply_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +# Copied from diffusers.models.transformers.transformer_ltx2.LTX2Attention +class LTX2Attention(torch.nn.Module, AttentionModuleMixin): + r""" + Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key + RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention. 
+ """ + + _default_processor_cls = LTX2AudioVideoAttnProcessor + _available_processors = [LTX2AudioVideoAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + kv_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = True, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + qk_norm: str = "rms_norm_across_heads", + norm_eps: float = 1e-6, + norm_elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + if qk_norm != "rms_norm_across_heads": + raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.") + + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = query_dim + self.heads = heads + + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + hidden_states = self.processor( + self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs + ) + return hidden_states + + +class LTX2RotaryPosEmbed1d(nn.Module): + """ + 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. + """ + + def __init__( + self, + dim: int, + base_seq_len: int = 4096, + theta: float = 10000.0, + double_precision: bool = True, + ): + super().__init__() + self.dim = dim + self.base_seq_len = base_seq_len + self.theta = theta + self.double_precision = double_precision + + def forward( + self, + batch_size: int, + pos: int, + device: Union[str, torch.device], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Get 1D position ids + grid_1d = torch.arange(pos, dtype=torch.float32, device=device) + # Get fractional indices relative to self.base_seq_len + grid_1d = grid_1d / self.base_seq_len + grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] + + # 2. 
Calculate 1D RoPE frequencies + num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2 + start = 1.0 + end = self.theta + if self.double_precision: + pow_indices = np.power( + self.theta, + np.linspace( + np.log(start) / np.log(self.theta), + np.log(end) / np.log(self.theta), + self.dim // num_rope_elems, + dtype=np.float64, + ), + ) + freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device) + else: + freqs = self.theta ** torch.linspace( + start=math.log(start, self.theta), + end=math.log(end, self.theta), + steps=self.dim // num_rope_elems, + device=device, + dtype=torch.float32, + ) + freqs = freqs * math.pi / 2.0 + + # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape + # (self.dim // 2,). + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2] + + # 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + return cos_freqs, sin_freqs + + +class LTX2TransformerBlock1d(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + activation_fn: str = "gelu-approximate", + eps: float = 1e-6, + ): + super().__init__() + + self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + processor=LTX2AudioVideoAttnProcessor(), + ) + + self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb) + hidden_states = hidden_states + attn_hidden_states + + norm_hidden_states = self.norm2(hidden_states) + ff_hidden_states = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_hidden_states + + return hidden_states + + +class LTX2ConnectorTransformer1d(nn.Module): + """ + A 1D sequence transformer for modalities such as text. + + In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams. 
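+
+    Example (a minimal, illustrative sketch with small dummy shapes; the LTX 2.0 connectors themselves default to
+    30 attention heads of dim 128):
+
+    ```py
+    >>> import torch
+    >>> connector = LTX2ConnectorTransformer1d(
+    ...     num_attention_heads=2, attention_head_dim=8, num_layers=1, num_learnable_registers=4
+    ... )
+    >>> hidden_states = torch.randn(1, 8, 16)  # (batch, seq_len, num_heads * head_dim)
+    >>> attn_mask = torch.zeros(1, 1, 1, 8)  # additive mask, 0.0 means "attend"
+    >>> out, out_mask = connector(hidden_states, attention_mask=attn_mask)
+    >>> out.shape
+    torch.Size([1, 8, 16])
+    ```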
+ """ + _supports_gradient_checkpointing = True + + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 128, + num_layers: int = 2, + num_learnable_registers: Optional[int] = 128, + rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + eps: float = 1e-6, + causal_temporal_positioning: bool = False, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + + self.num_learnable_registers = num_learnable_registers + self.learnable_registers = None + if num_learnable_registers is not None: + init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0 + self.learnable_registers = torch.nn.Parameter(init_registers) + + self.rope = LTX2RotaryPosEmbed1d( + self.inner_dim, base_seq_len=rope_base_seq_len, theta=rope_theta, double_precision=rope_double_precision + ) + + self.transformer_blocks = torch.nn.ModuleList( + [ + LTX2TransformerBlock1d( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False) + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # hidden_states shape: [batch_size, seq_len, hidden_dim] + # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len] + batch_size, seq_len, _ = hidden_states.shape + + # 1. Replace padding with learned registers, if using + if self.learnable_registers is not None: + if seq_len % self.num_learnable_registers != 0: + raise ValueError( + f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number" + f" of learnable registers {self.num_learnable_registers}" + ) + + num_register_repeats = seq_len // self.num_learnable_registers + registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim] + + binary_attn_mask = (attention_mask >= -9000.0).int() + if binary_attn_mask.ndim == 4: + binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] + + hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)] + valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded] + pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens] + padded_hidden_states = [ + F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths) + ] + padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D] + + flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1] + hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers + + # Overwrite attention_mask with an all-zeros mask if using registers. + attention_mask = torch.zeros_like(attention_mask) + + # 2. Calculate 1D RoPE positional embeddings + rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device) + + # 3. 
Run 1D transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb) + else: + hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb) + + hidden_states = self.norm_out(hidden_states) + + return hidden_states, attention_mask + + +class LTX2AudioVisualTextEncoder(ModelMixin, ConfigMixin): + ignore_for_config = ["text_model"] + + @register_to_config + def __init__( + self, + text_model: Optional[Gemma3ForConditionalGeneration] = None, + text_model_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized", + text_encoder_hidden_dim: Optional[int] = 3840, + text_proj_in_factor: Optional[int] = 49, # Num layers in text encoder + 1 + video_connector_num_attention_heads: int = 30, + video_connector_attention_head_dim: int = 128, + video_connector_num_layers: int = 2, + video_connector_num_learnable_registers: int = 128, + audio_connector_num_attention_heads: int = 30, + audio_connector_attention_head_dim: int = 128, + audio_connector_num_layers: int = 2, + audio_connector_num_learnable_registers: Optional[int] = 128, + rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + causal_temporal_positioning: bool = False, + config_only: bool = True, + ): + super().__init__() + if text_model is None: + self.set_base_text_encoder(text_model_id, config_only=config_only) + else: + self.base_text_encoder = text_model + + if text_encoder_hidden_dim is None: + if hasattr(self.base_text_encoder, "config"): + if hasattr(self.base_text_encoder.config, "hidden_size"): + text_encoder_hidden_dim = getattr(self.base_text_encoder.config, "hidden_size", None) + elif hasattr(self.base_text_encoder.config, "text_config"): + text_encoder_hidden_dim = getattr(self.base_text_encoder.config.text_config, "hidden_size", None) + if text_encoder_hidden_dim is None: + raise ValueError( + "`text_encoder_hidden_dim` is `None` and it cannot be inferred, please provide a value for it." + ) + + if text_proj_in_factor is None: + num_layers = None + if hasattr(self.base_text_encoder, "config"): + if hasattr(self.base_text_encoder.config, "num_hidden_layers"): + num_layers = getattr(self.base_text_encoder.config, "num_hidden_layers", None) + elif hasattr(self.base_text_encoder.config, "text_config"): + num_layers = getattr(self.base_text_encoder.config.text_config, "num_hidden_layers", None) + if num_layers is None: + raise ValueError( + "`text_proj_in_factor` is `None` and it cannot be inferred, please provide a value for it." 
+ ) + text_proj_in_factor = num_layers + 1 + + self.text_proj_in = nn.Linear( + text_encoder_hidden_dim * text_proj_in_factor, text_encoder_hidden_dim, bias=False + ) + + self.video_connector = LTX2ConnectorTransformer1d( + num_attention_heads=video_connector_num_attention_heads, + attention_head_dim=video_connector_attention_head_dim, + num_layers=video_connector_num_layers, + num_learnable_registers=video_connector_num_learnable_registers, + rope_base_seq_len=rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + ) + self.audio_connector = LTX2ConnectorTransformer1d( + num_attention_heads=audio_connector_num_attention_heads, + attention_head_dim=audio_connector_attention_head_dim, + num_layers=audio_connector_num_layers, + num_learnable_registers=audio_connector_num_learnable_registers, + rope_base_seq_len=rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + ) + + def set_base_text_encoder( + self, base_text_encoder_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized", config_only: bool = True + ): + if config_only: + base_text_encoder_config = AutoConfig.from_pretrained(base_text_encoder_id) + base_text_encoder = AutoModel.from_config(base_text_encoder_config) + else: + base_text_encoder = AutoModel.from_pretrained(base_text_encoder_id) + self.base_text_encoder = base_text_encoder + + @staticmethod + def pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: Union[str, torch.device], + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. 
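+
+        Example (illustrative sketch with small, arbitrary shapes rather than the real Gemma hidden sizes):
+
+        ```py
+        >>> import torch
+        >>> hidden = torch.randn(2, 16, 8, 4)  # (batch, seq_len, hidden_dim, num_layers)
+        >>> seq_lens = torch.tensor([16, 12])  # number of valid (non-padded) tokens per sample
+        >>> packed = LTX2AudioVisualTextEncoder.pack_text_embeds(hidden, seq_lens, device="cpu")
+        >>> packed.shape  # last dim is hidden_dim * num_layers = 8 * 4
+        torch.Size([2, 16, 32])
+        ```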
+ """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + def run_connectors( + self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Run LTX 2.0-specific text embedding post-processing logic on top of the base text encoder hidden_states. + + Args: + text_encoder_hidden_states (`torch.Tensor`): + Text encoder packed hidden_states of shape `(batch_size, seq_len, hidden_dim * (num_layers + 1))`. + attention_mask (`torch.Tensor`): + Attention mask of shape `(batch_size, seq_len)`. + + Returns: + `Tuple(torch.Tensor, torch.Tensor, torch.Tensor)]`: + Returns a 3-tuple of tensors where the first element is the video text embeddings of shape + `(batch_size, seq_len, hidden_dim)`, the second element is the audio text embeddings of shape + `(batch_size, seq_len, hidden_dim)`, and the third element is an attention mask of shape + `(batch_size, seq_len)`. 
+ """ + # Convert to additive attention mask + text_dtype = text_encoder_hidden_states.dtype + connector_attn_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + connector_attn_mask = connector_attn_mask.to(text_dtype) * torch.finfo(text_dtype).max + + text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states) + + video_text_embedding, new_attn_mask = self.video_connector( + text_encoder_hidden_states, connector_attn_mask + ) + + attn_mask = (new_attn_mask < 1e-6).to(torch.int64) + attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) + video_text_embedding = video_text_embedding * attn_mask + new_attn_mask = attn_mask.squeeze(-1) + + audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, connector_attn_mask) + + return video_text_embedding, audio_text_embedding, new_attn_mask + + def forward( + self, + text_input_ids, + attention_mask: Optional[torch.Tensor] = None, + padding_side: str = "left", + scale_factor: int = 8, + ): + text_encoder_outputs = self.base_text_encoder( + input_ids=text_input_ids, attention_mask=attention_mask, output_hidden_states=True + ) + + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = attention_mask.sum(dim=-1) + + text_encoder_hidden_states = self.pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=text_encoder_hidden_states.device, + padding_side=padding_side, + scale_factor=scale_factor, + ) + + video_text_embedding, audio_text_embedding, new_attn_mask = self.run_connectors( + text_encoder_hidden_states, attention_mask + ) + + return video_text_embedding, audio_text_embedding, new_attn_mask From d0f9cdaab10d58566b0f5eaf2cd2e90af1b94f47 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 22 Dec 2025 10:07:20 +0100 Subject: [PATCH 26/86] Rough initial LTX 2.0 pipeline implementation --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/ltx2/__init__.py | 48 + src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 978 ++++++++++++++++++ .../pipelines/ltx2/pipeline_output.py | 23 + 5 files changed, 1053 insertions(+) create mode 100644 src/diffusers/pipelines/ltx2/__init__.py create mode 100644 src/diffusers/pipelines/ltx2/pipeline_ltx2.py create mode 100644 src/diffusers/pipelines/ltx2/pipeline_output.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 71cad3425f0b..8be0e7fc0755 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -536,6 +536,7 @@ "LTXImageToVideoPipeline", "LTXLatentUpsamplePipeline", "LTXPipeline", + "LTX2Pipeline", "LucyEditPipeline", "Lumina2Pipeline", "Lumina2Text2ImgPipeline", @@ -1241,6 +1242,7 @@ LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline, + LTX2Pipeline, LucyEditPipeline, Lumina2Pipeline, Lumina2Text2ImgPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 388551f812f8..ef9430043bed 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -288,6 +288,7 @@ "LTXConditionPipeline", "LTXLatentUpsamplePipeline", ] + _import_structure["ltx2"] = ["LTX2Pipeline"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -719,6 +720,7 @@ 
LEditsPPPipelineStableDiffusionXL, ) from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline + from .ltx2 import LTX2Pipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py new file mode 100644 index 000000000000..7c1003660fd7 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_ltx2 import LTX2Pipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py new file mode 100644 index 000000000000..9373b21401ef --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -0,0 +1,978 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LTX2PipelineOutput +from .text_encoder import LTX2AudioVisualTextEncoder +from .vocoder import LTX2Vocoder + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTXPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> video = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=704, + ... height=480, + ... num_frames=161, + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=24) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. 
If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Pipeline for text-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. 
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Video, + text_encoder: LTX2AudioVisualTextEncoder, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + device: torch.device, + dtype: torch.dtype, + max_sequence_length: int = 1024, + scale_factor: int = 8, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. 
+ """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + + prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self.text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask.to(device), + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype) + + return prompt_embeds, audio_prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Audio text embeddings are only produced when the prompt is encoded here; initialize them so the
+ # return below does not fail when pre-computed `prompt_embeds` are passed or guidance is disabled.
+ audio_prompt_embeds = None
+ negative_audio_prompt_embeds = None
+
+ if prompt_embeds is None:
+ prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return (
+ prompt_embeds,
+ audio_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_audio_prompt_embeds,
+ negative_prompt_attention_mask,
+ )
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
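The packing comments above are easier to follow with concrete shapes. The following self-contained sketch (not part of the patch) mirrors the `_pack_latents`/`_unpack_latents` reshape-permute logic for the default `patch_size = patch_size_t = 1` used by this pipeline:

```python
import torch


def pack(latents, p=1, p_t=1):
    # [B, C, F, H, W] -> [B, F//p_t * H//p * W//p, C * p_t * p * p]
    b, c, f, h, w = latents.shape
    latents = latents.reshape(b, c, f // p_t, p_t, h // p, p, w // p, p)
    return latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)


def unpack(latents, f, h, w, p=1, p_t=1):
    # [B, S, D] -> [B, C, F, H, W] (inverse of `pack`)
    b = latents.size(0)
    latents = latents.reshape(b, f, h, w, -1, p_t, p, p)
    return latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)


x = torch.randn(1, 128, 8, 16, 24)                 # latent video: B=1, C=128, F=8, H=16, W=24
tokens = pack(x)                                   # -> [1, 8 * 16 * 24, 128] == [1, 3072, 128]
assert tokens.shape == (1, 3072, 128)
assert torch.equal(unpack(tokens, 8, 16, 24), x)   # exact round trip (reshape/permute only)
```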
+ batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + def _pack_audio_latents( + latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None + ) -> torch.Tensor: + # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins + if patch_size is not None and patch_size_t is not None: + # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (a ndim=3 tnesor). + # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size. + batch_size, num_channels, latent_length, num_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = num_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel) + # patch_size of M (all mel bins constitutes a single patch) and a patch_size_t of 1. + latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M] + return latents + + @staticmethod + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: Optional[int] = None, + patch_size_t: Optional[int] = None, + ) -> torch.Tensor: + # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], + # where L is the latent audio length and M is the number of mel bins. + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1. 
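Similarly, the default audio path (no explicit patch sizes) simply folds the channel and mel dimensions into one token per latent audio step. A short round-trip sketch, not part of the patch, using the latent sizes assumed elsewhere in this file (8 latent channels, 16 latent mel bins, roughly 121 latent steps for 121 frames at 25 fps):

```python
import torch

# [B, C, L, M] -> [B, L, C * M]: each latent audio step becomes one token
audio = torch.randn(1, 8, 121, 16)               # B=1, C=8, L=121, M=16
packed = audio.transpose(1, 2).flatten(2, 3)     # -> [1, 121, 128]
assert packed.shape == (1, 121, 128)

# inverse: [B, L, C * M] -> [B, C, L, M]
unpacked = packed.unflatten(2, (-1, 16)).transpose(1, 2)
assert torch.equal(unpacked, audio)
```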
+ latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + return latents + + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + num_mel_bins: int = 16, + num_frames: int = 121, + frame_rate: float = 25.0, + sampling_rate: int = 16000, + hop_length: int = 160, + audio_latent_scale_factor: int = 4, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + duration_s = num_frames / frame_rate + latents_per_second = float(sampling_rate) / float(hop_length) / float(audio_latent_scale_factor) + latent_length = int(duration_s * latents_per_second) + + shape = (batch_size, num_channels_latents, latent_length, num_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents, latent_length + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 25.0, + num_inference_steps: int = 40, + timesteps: List[int] = None, + guidance_scale: float = 3.0, + guidance_rescale: float = 0.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + audio_latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate + frame_rate (`float`, *optional*, defaults to `25.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `3.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. 
`callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + audio_prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_audio_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + audio_prompt_embeds = torch.cat([negative_audio_prompt_embeds, audio_prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + audio_latents, audio_num_frames = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=8, # TODO: get from audio VAE + num_mel_bins=16, # TODO: get from audio VAE + num_frames=num_frames, # Video frames, audio frames will be calculated from this + frame_rate=frame_rate, + sampling_rate=self.transformer.config.audio_sampling_rate, + hop_length=self.transformer.config.audio_hop_length, + audio_latent_scale_factor=4, # TODO: get from audio VAE + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + + # 7. 
Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+ audio_latent_model_input = torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
+ audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred_video, noise_pred_audio = self.transformer(
+ hidden_states=latent_model_input,
+ audio_hidden_states=audio_latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ audio_encoder_hidden_states=audio_prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ audio_encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ fps=frame_rate,
+ audio_num_frames=audio_num_frames,
+ # rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )
+ noise_pred_video = noise_pred_video.float()
+ noise_pred_audio = noise_pred_audio.float()
+
+ if self.do_classifier_free_guidance:
+ noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
+ noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (noise_pred_video_text - noise_pred_video_uncond)
+
+ noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
+ noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
+
+ if self.guidance_rescale > 0:
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
+ noise_pred_video = rescale_noise_cfg(
+ noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
+ )
+ noise_pred_audio = rescale_noise_cfg(
+ noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
+ # TODO: we probably can't call step on the same scheduler because it will mess with its internal
+ # state, how can we get around this? 
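One possible answer to the TODO above (a sketch, not the patch's decision): give the audio latents their own scheduler instance, configured with the same sigmas and `mu`, so the two `step()` calls do not share internal state such as the step index. All names below come from the surrounding pipeline code; `audio_scheduler` is hypothetical.

```python
# During step 5 (timestep preparation), alongside the call for `self.scheduler`:
audio_scheduler = FlowMatchEulerDiscreteScheduler.from_config(self.scheduler.config)
retrieve_timesteps(audio_scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)

# Inside the denoising loop, each stream then steps its own scheduler:
latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
```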
+ audio_latents = self.scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # TODO: get num_mel_bins from audio VAE or vocoder? + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=16) + # TODO: apply audio VAE decoder + audio = self.vocoder(audio_latents) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) diff --git a/src/diffusers/pipelines/ltx2/pipeline_output.py b/src/diffusers/pipelines/ltx2/pipeline_output.py new file mode 100644 index 000000000000..eacd571125b0 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_output.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class LTX2PipelineOutput(BaseOutput): + r""" + Output class for LTX pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ audio (`torch.Tensor`, `np.ndarray`): + TODO + """ + + frames: torch.Tensor + audio: torch.Tensor From 5f0f2a03f72fc59a606b1d7e03960b5c9a086102 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 22 Dec 2025 10:06:39 +0000 Subject: [PATCH 27/86] up --- scripts/log.txt | 32 ++++++++++ scripts/test_ltx2_audio_conversion.py | 59 +++++++++---------- .../autoencoders/autoencoder_kl_ltx2_audio.py | 20 +++++-- 3 files changed, 75 insertions(+), 36 deletions(-) create mode 100644 scripts/log.txt diff --git a/scripts/log.txt b/scripts/log.txt new file mode 100644 index 000000000000..aa3046d42abd --- /dev/null +++ b/scripts/log.txt @@ -0,0 +1,32 @@ +ddconfig={'double_z': True, 'mel_bins': 64, 'z_channels': 8, 'resolution': 256, 'downsample_time': False, 'in_channels': 2, 'out_ch': 2, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0, 'mid_block_add_attention': False, 'norm_type': 'pixel', 'causality_axis': 'height'}, sample_rate=16000, mel_hop_length=160, is_causal=True, mel_bins=64 +mid_block_add_attention=False, attn_resolutions=[] +k='mid.block_1.conv1.conv.weight' +k='mid.block_1.conv1.conv.bias' +k='mid.block_1.conv2.conv.weight' +k='mid.block_1.conv2.conv.bias' +k='mid.block_2.conv1.conv.weight' +k='mid.block_2.conv1.conv.bias' +k='mid.block_2.conv2.conv.weight' +k='mid.block_2.conv2.conv.bias' +Traceback (most recent call last): + File "/fsx/sayak/diffusers-new-model-addition-ltx2/scripts/test_ltx2_audio_conversion.py", line 97, in + main() + File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context + return func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/fsx/sayak/diffusers-new-model-addition-ltx2/scripts/test_ltx2_audio_conversion.py", line 85, in main + original_out = original_decoder(dummy) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/fsx/sayak/ltx-2/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py", line 206, in forward + sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/fsx/sayak/ltx-2/ltx-core/src/ltx_core/model/audio_vae/ops.py", line 27, in un_normalize + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + ~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +RuntimeError: The size of tensor a (512) must match the size of tensor b (128) at non-singleton dimension 2 diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py index f9554782c9f3..6a124f74df0d 100644 --- a/scripts/test_ltx2_audio_conversion.py +++ b/scripts/test_ltx2_audio_conversion.py @@ -1,7 +1,6 @@ import argparse from pathlib import Path -import safetensors.torch import torch from huggingface_hub import hf_hub_download @@ -9,11 +8,9 @@ def download_checkpoint( repo_id="diffusers-internal-dev/new-ltx-model", filename="ltx-av-step-1932500-interleaved-new-vae.safetensors", - device="cuda", ): ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) - ckpt = safetensors.torch.load_file(ckpt_path, device=device)["audio_vae"] - return ckpt 
+ return ckpt_path def convert_state_dict(state_dict: dict) -> dict: @@ -28,34 +25,33 @@ def convert_state_dict(state_dict: dict) -> dict: return converted -def load_original_decoder(device: torch.device, dtype: torch.dtype): - from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator - - with torch.device("meta"): - decoder = VAEDecoderConfigurator.from_config({}).to(device=device, dtype=dtype) - original_state_dict = download_checkpoint(device) - - decoder_state_dict = {} - for key, value in original_state_dict.items(): - if not isinstance(value, torch.Tensor): - continue - trimmed = key - if trimmed.startswith("audio_vae.decoder."): - trimmed = trimmed[len("audio_vae.decoder.") :] - elif trimmed.startswith("decoder."): - trimmed = trimmed[len("decoder.") :] - decoder_state_dict[trimmed] = value - decoder.load_state_dict(decoder_state_dict, strict=True, assign=True) - +def load_original_decoder(device: torch.device): + from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder + from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator + from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER + + checkpoint_path = download_checkpoint() + + # The code below comes from `ltx-pipelines/src/ltx_pipelines/txt2vid.py` + decoder = Builder( + model_path=checkpoint_path, + model_class_configurator=AudioDecoderConfigurator, + model_sd_key_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, + ).build(device=device) + + state_dict = decoder.state_dict() + for k, v in state_dict.items(): + if "mid" in k: + print(f"{k=}") decoder.eval() return decoder -def build_diffusers_decoder(device: torch.device, dtype: torch.dtype): +def build_diffusers_decoder(): from diffusers.models.autoencoders import AutoencoderKLLTX2Audio with torch.device("meta"): - model = AutoencoderKLLTX2Audio().to(device=device, dtype=dtype) + model = AutoencoderKLLTX2Audio() model.eval() return model @@ -74,16 +70,16 @@ def main() -> None: dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} dtype = dtype_map[args.dtype] - original_decoder = load_original_decoder(device, dtype) - diffusers_model = build_diffusers_decoder(device, dtype) + original_decoder = load_original_decoder(device) + diffusers_model = build_diffusers_decoder() - converted_state = convert_state_dict(original_decoder.state_dict()) - diffusers_model.load_state_dict(converted_state, assign=True, strict=True) + converted_state_dict = convert_state_dict(original_decoder.state_dict()) + diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=True) levels = len(diffusers_model.decoder.channel_multipliers) latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1)) dummy = torch.randn( - args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device, dtype=dtype + args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device ) original_out = original_decoder(dummy) @@ -93,6 +89,9 @@ def main() -> None: max_diff = (diffusers_out - original_out).abs().max().item() print(f"Conversion successful. 
Max diff: {max_diff:.6f}") + diffusers_model.to(dtype).save_pretrained(args.output_path) + print(f"Serialized model to {args.output_path}") + if __name__ == "__main__": main() diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 98d8a53e2359..457cbf5bce12 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -533,8 +533,9 @@ def _build_up_path( ) ) block_in = block_out - if curr_res in self.attn_resolutions: - stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) if level != 0: stage.upsample = LTX2AudioUpsample(block_in, resamp_with_conv, causality_axis=self.causality_axis) @@ -579,6 +580,13 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): _supports_gradient_checkpointing = False + # { + # 'double_z': True, 'mel_bins': 64, 'z_channels': 8, 'resolution': 256, 'downsample_time': False, + # 'in_channels': 2, 'out_ch': 2, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2, + # 'attn_resolutions': [], 'dropout': 0.0, 'mid_block_add_attention': False, + # 'norm_type': 'pixel', 'causality_axis': 'height' + # } + # sample_rate=16000, mel_hop_length=160, is_causal=True, mel_bins=64 @register_to_config def __init__( self, @@ -586,23 +594,23 @@ def __init__( output_channels: int = 2, ch_mult: Tuple[int] = (1, 2, 4), num_res_blocks: int = 2, - attn_resolutions: Tuple[int] = (8, 16, 32), + attn_resolutions: Optional[Tuple[int]] = None, in_channels: int = 2, resolution: int = 256, latent_channels: int = 8, norm_type: str = "pixel", causality_axis: Optional[str] = "height", dropout: float = 0.0, - mid_block_add_attention: bool = True, + mid_block_add_attention: bool = False, sample_rate: int = 16000, mel_hop_length: int = 160, is_causal: bool = True, - mel_bins: Optional[int] = None, + mel_bins: Optional[int] = 64, ) -> None: super().__init__() resolved_causality_axis = _resolve_causality_axis(causality_axis) - attn_resolution_set = set(attn_resolutions) + attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions self.decoder = LTX2AudioDecoder( base_channels=base_channels, From 58257eb0e0f1a8ac07ff4854009f35c1b2bad444 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 15:45:56 +0530 Subject: [PATCH 28/86] up --- scripts/test_ltx2_audio_conversion.py | 31 ++++++++++++------- .../autoencoders/autoencoder_kl_ltx2_audio.py | 7 ----- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py index 6a124f74df0d..8d07a6f9b1fe 100644 --- a/scripts/test_ltx2_audio_conversion.py +++ b/scripts/test_ltx2_audio_conversion.py @@ -25,13 +25,13 @@ def convert_state_dict(state_dict: dict) -> dict: return converted -def load_original_decoder(device: torch.device): +def load_original_decoder(device: torch.device, dtype: torch.dtype): from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder - from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER - + from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator + 
checkpoint_path = download_checkpoint() - + # The code below comes from `ltx-pipelines/src/ltx_pipelines/txt2vid.py` decoder = Builder( model_path=checkpoint_path, @@ -39,10 +39,6 @@ def load_original_decoder(device: torch.device): model_sd_key_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, ).build(device=device) - state_dict = decoder.state_dict() - for k, v in state_dict.items(): - if "mid" in k: - print(f"{k=}") decoder.eval() return decoder @@ -70,16 +66,27 @@ def main() -> None: dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} dtype = dtype_map[args.dtype] - original_decoder = load_original_decoder(device) + original_decoder = load_original_decoder(device, dtype) diffusers_model = build_diffusers_decoder() converted_state_dict = convert_state_dict(original_decoder.state_dict()) - diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=True) + diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=False) + + per_channel_len = original_decoder.per_channel_statistics.get_buffer("std-of-means").numel() + latent_channels = diffusers_model.decoder.latent_channels + mel_bins_for_match = per_channel_len // latent_channels if per_channel_len % latent_channels == 0 else None levels = len(diffusers_model.decoder.channel_multipliers) - latent_size = diffusers_model.decoder.resolution // (2 ** (levels - 1)) + latent_height = diffusers_model.decoder.resolution // (2 ** (levels - 1)) + latent_width = mel_bins_for_match or latent_height + dummy = torch.randn( - args.batch, diffusers_model.decoder.latent_channels, latent_size, latent_size, device=device + args.batch, + diffusers_model.decoder.latent_channels, + latent_height, + latent_width, + device=device, + dtype=dtype, ) original_out = original_decoder(dummy) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 457cbf5bce12..e7960c3e14bf 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -580,13 +580,6 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): _supports_gradient_checkpointing = False - # { - # 'double_z': True, 'mel_bins': 64, 'z_channels': 8, 'resolution': 256, 'downsample_time': False, - # 'in_channels': 2, 'out_ch': 2, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2, - # 'attn_resolutions': [], 'dropout': 0.0, 'mid_block_add_attention': False, - # 'norm_type': 'pixel', 'causality_axis': 'height' - # } - # sample_rate=16000, mel_hop_length=160, is_causal=True, mel_bins=64 @register_to_config def __init__( self, From 059999a3f7ad3fe3077f61812e3b3de91136f4bb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 22 Dec 2025 10:24:55 +0000 Subject: [PATCH 29/86] up --- scripts/log.txt | 32 ------------------- .../autoencoders/autoencoder_kl_ltx2_audio.py | 22 +++++++------ 2 files changed, 12 insertions(+), 42 deletions(-) delete mode 100644 scripts/log.txt diff --git a/scripts/log.txt b/scripts/log.txt deleted file mode 100644 index aa3046d42abd..000000000000 --- a/scripts/log.txt +++ /dev/null @@ -1,32 +0,0 @@ -ddconfig={'double_z': True, 'mel_bins': 64, 'z_channels': 8, 'resolution': 256, 'downsample_time': False, 'in_channels': 2, 'out_ch': 2, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0, 'mid_block_add_attention': False, 'norm_type': 'pixel', 'causality_axis': 'height'}, sample_rate=16000, 
mel_hop_length=160, is_causal=True, mel_bins=64 -mid_block_add_attention=False, attn_resolutions=[] -k='mid.block_1.conv1.conv.weight' -k='mid.block_1.conv1.conv.bias' -k='mid.block_1.conv2.conv.weight' -k='mid.block_1.conv2.conv.bias' -k='mid.block_2.conv1.conv.weight' -k='mid.block_2.conv1.conv.bias' -k='mid.block_2.conv2.conv.weight' -k='mid.block_2.conv2.conv.bias' -Traceback (most recent call last): - File "/fsx/sayak/diffusers-new-model-addition-ltx2/scripts/test_ltx2_audio_conversion.py", line 97, in - main() - File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context - return func(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^ - File "/fsx/sayak/diffusers-new-model-addition-ltx2/scripts/test_ltx2_audio_conversion.py", line 85, in main - original_out = original_decoder(dummy) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/fsx/sayak/ltx-2/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py", line 206, in forward - sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/fsx/sayak/ltx-2/ltx-core/src/ltx_core/model/audio_vae/ops.py", line 27, in un_normalize - return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) - ~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -RuntimeError: The size of tensor a (512) must match the size of tensor b (128) at non-singleton dimension 2 diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index e7960c3e14bf..1385b414b975 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -99,8 +99,9 @@ def __init__( super().__init__() self.causality_axis = causality_axis - kernel_size = nn.modules.utils._pair(kernel_size) - dilation = nn.modules.utils._pair(dilation) + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + dilation = (dilation, dilation) if isinstance(dilation, int) else dilation + pad_h = (kernel_size[0] - 1) * dilation[0] pad_w = (kernel_size[1] - 1) * dilation[1] @@ -232,7 +233,7 @@ def __init__( def forward( self, x: torch.Tensor, - temb: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None ) -> torch.Tensor: h = self.norm1(x) h = self.non_linearity(h) @@ -257,7 +258,7 @@ def __init__( self, in_channels: int, with_conv: bool, - causality_axis: Optional[str] = "height", + causality_axis: Optional[str] = "height" ) -> None: super().__init__() self.with_conv = with_conv @@ -291,10 +292,11 @@ class LTX2AudioPerChannelStatistics(nn.Module): def __init__(self, latent_channels: int = 128) -> None: super().__init__() + # Sayak notes: `empty` always causes problems in CI. Should we consider using `torch.ones`? 
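On the question raised in the comment above: `torch.empty` leaves the buffers uninitialized, so any test that constructs the module without loading real weights sees arbitrary, run-to-run varying values. A deterministic placeholder that behaves as an identity until `load_state_dict(..., assign=True)` overwrites it could look like the sketch below (an illustration of the `torch.ones`/`torch.zeros` option, not the patch's final choice):

```python
import torch
import torch.nn as nn


class PerChannelStatistics(nn.Module):
    """Minimal stand-in for LTX2AudioPerChannelStatistics with deterministic init."""

    def __init__(self, latent_channels: int = 128) -> None:
        super().__init__()
        # ones/zeros act as identity statistics until real values are loaded;
        # torch.empty would leave these uninitialized and make weight-free tests flaky.
        self.register_buffer("std-of-means", torch.ones(latent_channels))
        self.register_buffer("mean-of-means", torch.zeros(latent_channels))

    def denormalize(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.get_buffer("std-of-means").to(x) + self.get_buffer("mean-of-means").to(x)
```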
self.register_buffer("std-of-means", torch.empty(latent_channels)) self.register_buffer("mean-of-means", torch.empty(latent_channels)) - def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + def denormalize(self, x: torch.Tensor) -> torch.Tensor: return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) def normalize(self, x: torch.Tensor) -> torch.Tensor: @@ -327,7 +329,7 @@ def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor: def unpatchify( self, audio_latents: torch.Tensor, - output_shape: AudioLatentShape, + output_shape: AudioLatentShape ) -> torch.Tensor: batch, time, _ = audio_latents.shape channels = output_shape.channels @@ -421,7 +423,7 @@ def __init__( def _adjust_output_shape( self, decoded_output: torch.Tensor, - target_shape: AudioLatentShape, + target_shape: AudioLatentShape ) -> torch.Tensor: _, _, current_time, current_freq = decoded_output.shape target_channels = target_shape.channels @@ -460,7 +462,7 @@ def forward( ) sample_patched = self.patchifier.patchify(sample) - sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) + sample_denormalized = self.per_channel_statistics.denormalize(sample_patched) sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR @@ -509,7 +511,7 @@ def _build_up_path( self, initial_block_channels: int, dropout: float, - resamp_with_conv: bool, + resamp_with_conv: bool ) -> tuple[nn.ModuleList, int]: up_modules = nn.ModuleList() block_in = initial_block_channels @@ -630,7 +632,7 @@ def __init__( def encode( self, x: torch.Tensor, - return_dict: bool = True, + return_dict: bool = True ): raise NotImplementedError("AutoencoderKLLTX2Audio does not implement encoding.") From 8134da6a56d2fe3fde82af00f079b0615d9768e8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 22 Dec 2025 15:55:29 +0530 Subject: [PATCH 30/86] up --- .../autoencoders/autoencoder_kl_ltx2_audio.py | 37 +++---------------- 1 file changed, 6 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 1385b414b975..e3c0ef2c3ddc 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -102,7 +102,6 @@ def __init__( kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size dilation = (dilation, dilation) if isinstance(dilation, int) else dilation - pad_h = (kernel_size[0] - 1) * dilation[0] pad_w = (kernel_size[1] - 1) * dilation[1] @@ -230,11 +229,7 @@ def __init__( in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis ) - def forward( - self, - x: torch.Tensor, - temb: Optional[torch.Tensor] = None - ) -> torch.Tensor: + def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: h = self.norm1(x) h = self.non_linearity(h) h = self.conv1(h) @@ -254,12 +249,7 @@ def forward( class LTX2AudioUpsample(nn.Module): - def __init__( - self, - in_channels: int, - with_conv: bool, - causality_axis: Optional[str] = "height" - ) -> None: + def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None: super().__init__() self.with_conv = with_conv self.causality_axis = causality_axis @@ -326,11 +316,7 @@ def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor: batch, channels, time, freq = 
audio_latents.shape return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq) - def unpatchify( - self, - audio_latents: torch.Tensor, - output_shape: AudioLatentShape - ) -> torch.Tensor: + def unpatchify(self, audio_latents: torch.Tensor, output_shape: AudioLatentShape) -> torch.Tensor: batch, time, _ = audio_latents.shape channels = output_shape.channels freq = output_shape.mel_bins @@ -420,11 +406,7 @@ def __init__( final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis ) - def _adjust_output_shape( - self, - decoded_output: torch.Tensor, - target_shape: AudioLatentShape - ) -> torch.Tensor: + def _adjust_output_shape(self, decoded_output: torch.Tensor, target_shape: AudioLatentShape) -> torch.Tensor: _, _, current_time, current_freq = decoded_output.shape target_channels = target_shape.channels target_time = target_shape.frames @@ -508,10 +490,7 @@ def _build_mid_layers(self, channels: int, dropout: float, add_attention: bool) return mid def _build_up_path( - self, - initial_block_channels: int, - dropout: float, - resamp_with_conv: bool + self, initial_block_channels: int, dropout: float, resamp_with_conv: bool ) -> tuple[nn.ModuleList, int]: up_modules = nn.ModuleList() block_in = initial_block_channels @@ -629,11 +608,7 @@ def __init__( self.use_slicing = False @apply_forward_hook - def encode( - self, - x: torch.Tensor, - return_dict: bool = True - ): + def encode(self, x: torch.Tensor, return_dict: bool = True): raise NotImplementedError("AutoencoderKLLTX2Audio does not implement encoding.") def _decode(self, z: torch.Tensor) -> torch.Tensor: From 5f7e43d17fe6edf60fe4dcd8b0d8320e84a259ac Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 02:08:51 +0100 Subject: [PATCH 31/86] Add imports for LTX 2.0 Audio VAE --- src/diffusers/__init__.py | 2 ++ src/diffusers/models/__init__.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 71cad3425f0b..8c6761a07e3e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -194,6 +194,7 @@ "AutoencoderKLHunyuanVideo", "AutoencoderKLHunyuanVideo15", "AutoencoderKLLTXVideo", + "AutoencoderKLLTX2Audio", "AutoencoderKLLTX2Video", "AutoencoderKLMagvit", "AutoencoderKLMochi", @@ -929,6 +930,7 @@ AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, + AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, AutoencoderKLMagvit, AutoencoderKLMochi, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 3f4e49015b59..d3bcb3bcee7a 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -42,6 +42,7 @@ _import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"] + _import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"] @@ -154,6 +155,7 @@ AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, + AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, AutoencoderKLMagvit, AutoencoderKLMochi, From 
d303e2a6ff841919531facf302fd0e724ae57d33 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 02:48:08 +0100 Subject: [PATCH 32/86] Conversion script for LTX 2.0 Audio VAE Decoder --- scripts/convert_ltx2_to_diffusers.py | 80 +++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 2 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index f2e879c06562..eb130a354945 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -8,7 +8,7 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel +from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel from diffusers.utils.import_utils import is_accelerate_available from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder @@ -62,6 +62,8 @@ "per_channel_statistics.std-of-means": "latents_std", } +LTX_2_0_AUDIO_VAE_RENAME_DICT = {} + LTX_2_0_VOCODER_RENAME_DICT = { "ups": "upsamplers", "resblocks": "resnets", @@ -96,6 +98,15 @@ def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) return +def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None: + if key.startswith("per_channel_statistics"): + new_key = ".".join(["decoder", key]) + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = { "video_embeddings_connector": remove_keys_inplace, "audio_embeddings_connector": remove_keys_inplace, @@ -107,6 +118,11 @@ def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) "per_channel_statistics.mean-of-stds": remove_keys_inplace, } +LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = { + "encoder": remove_keys_inplace, + "per_channel_statistics": convert_ltx2_audio_vae_per_channel_statistics, +} + LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} @@ -325,6 +341,60 @@ def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> return vae +def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "base_channels": 128, + "output_channels": 2, + "ch_mult": (1, 2, 4), + "num_res_blocks": 2, + "attn_resolutions": None, + "in_channels": 2, + "resolution": 256, + "latent_channels": 8, + "norm_type": "pixel", + "causality_axis": "height", + "dropout": 0.0, + "mid_block_add_attention": False, + "sample_rate": 16000, + "mel_hop_length": 160, + "is_causal": True, + "mel_bins": 64, + }, + } + rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vae = AutoencoderKLLTX2Audio.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by 
a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "2.0": config = { @@ -513,7 +583,13 @@ def main(args): vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) if args.audio_vae or args.full_pipeline: - pass + if args.audio_vae_filename is not None: + original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename) + elif combined_ckpt is not None: + original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix) + audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version) + if not args.full_pipeline: + audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae")) if args.dit or args.full_pipeline: if args.dit_filename is not None: From 54bfc5d6178fec3e2a52e86f73b9f2aaf0e68927 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 03:51:22 +0100 Subject: [PATCH 33/86] Add Audio VAE logic to T2V pipeline --- .../autoencoders/autoencoder_kl_ltx2_audio.py | 4 ++ src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 67 +++++++++++++------ 2 files changed, 49 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index e3c0ef2c3ddc..90ddf2aa6e6b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -605,6 +605,10 @@ def __init__( mel_bins=mel_bins, ) + # TODO: calculate programmatically instead of hardcoding + self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4 + # TODO: confirm whether the mel compression ratio below is correct + self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR self.use_slicing = False @apply_forward_hook diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 9373b21401ef..99160a38be6c 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -21,7 +21,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin -from ...models.autoencoders import AutoencoderKLLTX2Video +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video from ...models.transformers import LTX2VideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring @@ -201,7 +201,7 @@ def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKLLTX2Video, - audio_vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, text_encoder: LTX2AudioVisualTextEncoder, tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], transformer: LTX2VideoTransformer3DModel, @@ -225,6 +225,13 @@ def __init__( self.vae_temporal_compression_ratio = ( self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 ) + # TODO: check whether the MEL compression ratio logic here is corrct + 
self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) self.transformer_spatial_patch_size = ( self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 ) @@ -232,6 +239,13 @@ def __init__( self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 ) + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) self.tokenizer_max_length = ( self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 @@ -487,9 +501,9 @@ def _pack_audio_latents( if patch_size is not None and patch_size_t is not None: # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (a ndim=3 tnesor). # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size. - batch_size, num_channels, latent_length, num_mel_bins = latents.shape + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape post_patch_latent_length = latent_length / patch_size_t - post_patch_mel_bins = num_mel_bins / patch_size + post_patch_mel_bins = latent_mel_bins / patch_size latents = latents.reshape( batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size ) @@ -556,12 +570,11 @@ def prepare_audio_latents( self, batch_size: int = 1, num_channels_latents: int = 8, - num_mel_bins: int = 16, + num_mel_bins: int = 64, num_frames: int = 121, frame_rate: float = 25.0, sampling_rate: int = 16000, hop_length: int = 160, - audio_latent_scale_factor: int = 4, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, @@ -571,10 +584,13 @@ def prepare_audio_latents( return latents.to(device=device, dtype=dtype) duration_s = num_frames / frame_rate - latents_per_second = float(sampling_rate) / float(hop_length) / float(audio_latent_scale_factor) + latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) latent_length = int(duration_s * latents_per_second) - shape = (batch_size, num_channels_latents, latent_length, num_mel_bins) + # TODO: confirm whether this logic is correct + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -792,6 +808,11 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. 
Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, @@ -805,15 +826,20 @@ def __call__( latents, ) + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) audio_latents, audio_num_frames = self.prepare_audio_latents( batch_size * num_videos_per_prompt, - num_channels_latents=8, # TODO: get from audio VAE - num_mel_bins=16, # TODO: get from audio VAE + num_channels_latents=num_channels_latents_audio, + num_mel_bins=num_mel_bins, num_frames=num_frames, # Video frames, audio frames will be calculated from this frame_rate=frame_rate, - sampling_rate=self.transformer.config.audio_sampling_rate, - hop_length=self.transformer.config.audio_hop_length, - audio_latent_scale_factor=4, # TODO: get from audio VAE + sampling_rate=self.audio_sampling_rate, + hop_length=self.audio_hop_length, dtype=torch.float32, device=device, generator=generator, @@ -821,10 +847,6 @@ def __call__( ) # 5. Prepare timesteps - latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 - latent_height = height // self.vae_spatial_compression_ratio - latent_width = width // self.vae_spatial_compression_ratio - video_sequence_length = latent_num_frames * latent_height * latent_width sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) mu = calculate_shift( video_sequence_length, @@ -964,10 +986,11 @@ def __call__( video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) - # TODO: get num_mel_bins from audio VAE or vocoder? 
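As a quick cross-check of the sizing logic in prepare_audio_latents above, the following standalone sketch reproduces the audio latent grid for the defaults wired up in this patch (16 kHz sample rate, hop length 160, temporal and mel compression ratios of 4, 64 mel bins); the numbers in the comments are just the result of this arithmetic, not values read from a checkpoint:

import torch

num_frames, frame_rate = 121, 25.0          # video settings
sampling_rate, hop_length = 16000, 160      # audio VAE defaults in this patch
temporal_compression = 4                    # audio_vae_temporal_compression_ratio
mel_compression = 4                         # audio_vae_mel_compression_ratio
num_mel_bins = 64                           # mel bins before compression

duration_s = num_frames / frame_rate                                    # 4.84 s of video
latents_per_second = sampling_rate / hop_length / temporal_compression  # 25.0 latent steps per second
latent_length = int(duration_s * latents_per_second)                    # 121
latent_mel_bins = num_mel_bins // mel_compression                       # 64 // 4 = 16

# prepare_audio_latents then draws noise of shape (batch, latent_channels, latent_length, latent_mel_bins)
audio_latents = torch.randn(1, 8, latent_length, latent_mel_bins)       # (1, 8, 121, 16)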
- audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=16) - # TODO: apply audio VAE decoder - audio = self.vocoder(audio_latents) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + # NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's + # decode method + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + waveforms = self.vocoder(generated_mel_spectrograms) # Offload all models self.maybe_free_model_hooks() @@ -975,4 +998,4 @@ def __call__( if not return_dict: return (video, audio) - return LTX2PipelineOutput(frames=video, audio=audio) + return LTX2PipelineOutput(frames=video, audio=waveforms) From 6e6ce2059502d5bfb1f89e0dcbb10c6d86ff4f0f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 06:40:35 +0100 Subject: [PATCH 34/86] Duplicate scheduler for audio latents --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 99160a38be6c..250ff7284f4f 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect from typing import Any, Callable, Dict, List, Optional, Union @@ -865,6 +866,16 @@ def __call__( ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) # 6. Prepare micro-conditions rope_interpolation_scale = ( @@ -928,9 +939,9 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] - # TODO: we probably can't call step on the same scheduler because it will mess with its internal - # state, how can we get around this? 
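The removed TODO above is resolved by keeping a second, identically configured scheduler for the audio branch, precisely because FlowMatchEulerDiscreteScheduler.step() advances internal state such as _step_index. A minimal self-contained illustration of the pattern with toy tensors (not the pipeline code itself):

import copy

import torch

from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=4)
audio_scheduler = copy.deepcopy(scheduler)  # same config and timesteps, independent step state

video_latents = torch.randn(1, 8)
audio_latents = torch.randn(1, 8)
for t in scheduler.timesteps:
    # stand-ins for the transformer's video and audio noise predictions
    noise_pred_video = torch.randn_like(video_latents)
    noise_pred_audio = torch.randn_like(audio_latents)
    # stepping both branches through one shared scheduler would consume two sigmas per iteration
    video_latents = scheduler.step(noise_pred_video, t, video_latents, return_dict=False)[0]
    audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]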
- audio_latents = self.scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} From cbb10b8dcae1ea9588fbb31aaadb7c60d3bba27f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 07:01:17 +0100 Subject: [PATCH 35/86] Support num_videos_per_prompt for prompt embeddings --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 250ff7284f4f..af9b0096fd46 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -255,10 +255,11 @@ def __init__( def _get_gemma_prompt_embeds( self, prompt: Union[str, List[str]], - device: torch.device, - dtype: torch.dtype, + num_videos_per_prompt: int = 1, max_sequence_length: int = 1024, scale_factor: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -272,7 +273,11 @@ def _get_gemma_prompt_embeds( torch dtype to cast the prompt embeds to max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.base_text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if getattr(self, "tokenizer", None) is not None: # Gemma expects left padding for chat-style prompts @@ -301,6 +306,18 @@ def _get_gemma_prompt_embeds( prompt_embeds = prompt_embeds.to(dtype=dtype) audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype) + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + _, audio_seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + return prompt_embeds, audio_prompt_embeds, prompt_attention_mask def encode_prompt( @@ -310,10 +327,13 @@ def encode_prompt( do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, + audio_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_audio_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, - max_sequence_length: int = 128, + max_sequence_length: int = 1024, + scale_factor: int = 8, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -356,6 +376,7 @@ def encode_prompt( prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, + scale_factor=scale_factor, device=device, dtype=dtype, ) @@ -380,6 +401,7 @@ 
def encode_prompt( prompt=negative_prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, + scale_factor=scale_factor, device=device, dtype=dtype, ) @@ -650,8 +672,10 @@ def __call__( latents: Optional[torch.Tensor] = None, audio_latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, + audio_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_audio_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, decode_timestep: Union[float, List[float]] = 0.0, decode_noise_scale: Optional[Union[float, List[float]]] = None, @@ -712,11 +736,17 @@ def __call__( prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + audio_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings for audio processing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings for audio processing. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for negative text embeddings. 
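The duplication added to _get_gemma_prompt_embeds above is the usual repeat/view trick; a standalone sketch with toy shapes (the audio prompt embeddings would be tiled the same way):

import torch

batch_size, seq_len, dim = 2, 6, 8
num_videos_per_prompt = 3
prompt_embeds = torch.randn(batch_size, seq_len, dim)

# Tile each prompt num_videos_per_prompt times along the batch dimension, keeping all
# copies of a given prompt adjacent: [p0, p0, p0, p1, p1, p1].
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)                    # (2, 18, 8)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)  # (6, 6, 8)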
decode_timestep (`float`, defaults to `0.0`): @@ -797,7 +827,9 @@ def __call__( do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, + audio_prompt_embeds=audio_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + negative_audio_prompt_embeds=negative_audio_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, max_sequence_length=max_sequence_length, From 595f485ad8e1449eeae29639bd5e09b3887eb4f0 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 07:41:28 +0100 Subject: [PATCH 36/86] LTX 2.0 scheduler and full pipeline conversion --- scripts/convert_ltx2_to_diffusers.py | 35 ++++++++++++++++--- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 6 ++-- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 78494a52b9f8..6c4cac739632 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -7,9 +7,15 @@ import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from transformers import AutoModel, AutoProcessor - -from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel +from transformers import AutoModel, AutoTokenizer + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2Pipeline, + LTX2VideoTransformer3DModel, +) from diffusers.utils.import_utils import is_accelerate_available from diffusers.pipelines.ltx2.text_encoder import LTX2AudioVisualTextEncoder from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder @@ -721,12 +727,31 @@ def main(args): if not args.full_pipeline: text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder")) - tokenizer = AutoProcessor.from_pretrained(args.tokenizer_id) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id) if not args.full_pipeline: tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer")) if args.full_pipeline: - pass + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=True, + base_shift=0.95, + max_shift=2.05, + base_image_seq_len=1024, + max_image_seq_len=4096, + shift_terminal=0.1, + ) + + pipe = LTX2Pipeline( + scheduler=scheduler, + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + vocoder=vocoder, + ) + + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") if __name__ == '__main__': diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index af9b0096fd46..eff87c08a320 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -883,10 +883,10 @@ def __call__( sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) mu = calculate_shift( video_sequence_length, - self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_image_seq_len", 1024), self.scheduler.config.get("max_image_seq_len", 4096), - self.scheduler.config.get("base_shift", 0.5), - self.scheduler.config.get("max_shift", 1.15), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, From 
3bf736979fce8c8753d063018ca0dc0787d56aed Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 08:43:37 +0100 Subject: [PATCH 37/86] Add script to test full LTX2Pipeline T2V inference --- scripts/ltx2_test_full_pipeline.py | 213 +++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 scripts/ltx2_test_full_pipeline.py diff --git a/scripts/ltx2_test_full_pipeline.py b/scripts/ltx2_test_full_pipeline.py new file mode 100644 index 000000000000..019bbda46d7b --- /dev/null +++ b/scripts/ltx2_test_full_pipeline.py @@ -0,0 +1,213 @@ +import argparse +import os +from fractions import Fraction +from typing import Optional + +import av # Needs to be installed separately (`pip install av`) +import torch + +from diffusers import LTX2Pipeline + + +# Video export functions copied from original LTX 2.0 code +def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: + """ + Prepare the audio stream for writing. + """ + audio_stream = container.add_stream("aac", rate=audio_sample_rate) + audio_stream.codec_context.sample_rate = audio_sample_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) + return audio_stream + + +def _resample_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame +) -> None: + cc = audio_stream.codec_context + + # Use the encoder's format/layout/rate as the *target* + target_format = cc.format or "fltp" # AAC → usually fltp + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or frame_in.sample_rate + + audio_resampler = av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in audio_resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + # flush audio encoder + for packet in audio_stream.encode(): + container.mux(packet) + + +def _write_audio( + container: av.container.Container, + audio_stream: av.audio.AudioStream, + samples: torch.Tensor, + audio_sample_rate: int, +) -> None: + if samples.ndim == 1: + samples = samples[:, None] + + if samples.shape[1] != 2 and samples.shape[0] == 2: + samples = samples.T + + if samples.shape[1] != 2: + raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.") + + # Convert to int16 packed for ingestion; resampler converts to encoder fmt. 
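One detail worth spelling out before the int16 conversion below: for a packed sample format such as s16, av.AudioFrame.from_ndarray takes a single interleaved row of shape (1, num_samples * num_channels), which is what the reshape(1, -1) produces. A toy sketch of the packing (torch only, no av required):

import torch

samples = torch.tensor([[0.1, -0.1], [0.2, -0.2], [0.3, -0.3]])   # (num_samples, 2), float in [-1, 1]
pcm = (torch.clip(samples, -1.0, 1.0) * 32767.0).to(torch.int16)  # scale to 16-bit PCM
packed = pcm.contiguous().reshape(1, -1).numpy()                  # (1, 6): L0, R0, L1, R1, L2, R2
print(packed.dtype, packed.shape)                                 # int16 (1, 6)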
+ if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + _resample_audio(container, audio_stream, frame_in) + + +def encode_video( + video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str +) -> None: + video_np = video.cpu().numpy() + + _, height, width, _ = video_np.shape + + container = av.open(output_path, mode="w") + stream = container.add_stream("libx264", rate=int(fps)) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + if audio is not None: + if audio_sample_rate is None: + raise ValueError("audio_sample_rate is required when audio is provided") + + audio_stream = _prepare_audio_stream(container, audio_sample_rate) + + for frame_array in video_np: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + + # Flush encoder + for packet in stream.encode(): + container.mux(packet) + + if audio is not None: + _write_audio(container, audio_stream, audio, audio_sample_rate) + + container.close() + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model") + parser.add_argument("--revision", type=str, default="main") + + parser.add_argument( + "--prompt", + type=str, + default="A video of a dog dancing to energetic electronic dance music", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default=( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio,incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+ ), + ) + + parser.add_argument("--num_inference_steps", type=int, default=40) + parser.add_argument("--height", type=int, default=512) + parser.add_argument("--width", type=int, default=768) + parser.add_argument("--num_frames", type=int, default=121) + parser.add_argument("--frame_rate", type=float, default=25.0) + parser.add_argument("--guidance_scale", type=float, default=3.0) + parser.add_argument("--seed", type=int, default=42) + + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--dtype", type=str, default="bf16") + + parser.add_argument( + "--output_dir", + type=str, + default="/home/daniel_gu/samples", + help="Output directory for generated video", + ) + parser.add_argument( + "--output_filename", + type=str, + default="ltx2_sample_video.mp4", + help="Filename of the exported generated video", + ) + + args = parser.parse_args() + args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 + return args + + +def main(args): + pipeline = LTX2Pipeline.from_pretrained( + args.model_id, + revision=args.revision, + torch_dtype=args.dtype, + ) + pipeline.to(device=args.device) + + video, audio = pipeline( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + generator=torch.Generator(device=args.device).manual_seed(args.seed), + output_type="np", + ) + + # Convert video to uint8 (but keep as NumPy array) + video = (video * 255).round().astype("uint8") + video = torch.from_numpy(video) + # video should already be frames first, reshape to channels-last (we want shape to be (*, F, H , W, C)) + video = video.permute(0, 1, 3, 4, 2) + + encode_video( + video[0], + fps=args.frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipeline.vocoder.config.output_sampling_rate, # should be 24000 + output_path=os.path.join(args.output_dir, args.output_filename), + ) + + +if __name__ == '__main__': + args = parse_args() + main(args) From fa7d9f77f143c20b8049f630fd20d3fe75fe1342 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 08:49:11 +0100 Subject: [PATCH 38/86] Fix pipeline return bugs --- scripts/ltx2_test_full_pipeline.py | 1 + src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/ltx2_test_full_pipeline.py b/scripts/ltx2_test_full_pipeline.py index 019bbda46d7b..1907be2da8e4 100644 --- a/scripts/ltx2_test_full_pipeline.py +++ b/scripts/ltx2_test_full_pipeline.py @@ -191,6 +191,7 @@ def main(args): guidance_scale=args.guidance_scale, generator=torch.Generator(device=args.device).manual_seed(args.seed), output_type="np", + return_dict=False, ) # Convert video to uint8 (but keep as NumPy array) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index eff87c08a320..e8a41050f5d3 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -1033,7 +1033,7 @@ def __call__( # NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's # decode method generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] - waveforms = self.vocoder(generated_mel_spectrograms) + audio = self.vocoder(generated_mel_spectrograms) # Offload all models self.maybe_free_model_hooks() @@ -1041,4 +1041,4 @@ def 
__call__( if not return_dict: return (video, audio) - return LTX2PipelineOutput(frames=video, audio=waveforms) + return LTX2PipelineOutput(frames=video, audio=audio) From a56cf23483d097e6d0d4b40f8259294e6074a880 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 10:40:56 +0100 Subject: [PATCH 39/86] Add LTX 2 text encoder and vocoder to ltx2 subdirectory __init__ --- src/diffusers/pipelines/ltx2/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index 7c1003660fd7..d23123089fb8 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -23,6 +23,8 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] + _import_structure["text_encoder"] = ["LTX2AudioVisualTextEncoder"] + _import_structure["vocoder"] = ["LTX2Vocoder"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +35,8 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_ltx2 import LTX2Pipeline + from .text_encoder import LTX2AudioVisualTextEncoder + from .vocoder import LTX2Vocoder else: import sys From 90edc6abc94d9e22e19b91c573ea97a9da05c68f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 10:41:27 +0100 Subject: [PATCH 40/86] Fix more bugs in LTX2Pipeline.__call__ --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index e8a41050f5d3..5aed290aa0cf 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -888,26 +888,26 @@ def __call__( self.scheduler.config.get("base_shift", 0.95), self.scheduler.config.get("max_shift", 2.05), ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu, ) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - # For now, duplicate the scheduler for use with the audio latents - audio_scheduler = copy.deepcopy(self.scheduler) - _, _ = retrieve_timesteps( - audio_scheduler, + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu, ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) # 6. 
Prepare micro-conditions rope_interpolation_scale = ( @@ -937,7 +937,7 @@ def __call__( hidden_states=latent_model_input, audio_hidden_states=audio_latent_model_input, encoder_hidden_states=prompt_embeds, - audio_encoder_hidden_states=audio_latent_model_input, + audio_encoder_hidden_states=audio_prompt_embeds, timestep=timestep, encoder_attention_mask=prompt_attention_mask, audio_encoder_attention_mask=prompt_attention_mask, From 1484c43183a18488c73618abe3e64354deb1acbd Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 10:56:32 +0100 Subject: [PATCH 41/86] Improve CPU offload support --- scripts/ltx2_test_full_pipeline.py | 3 +++ src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/ltx2_test_full_pipeline.py b/scripts/ltx2_test_full_pipeline.py index 1907be2da8e4..63c6f0400f6d 100644 --- a/scripts/ltx2_test_full_pipeline.py +++ b/scripts/ltx2_test_full_pipeline.py @@ -153,6 +153,7 @@ def parse_args(): parser.add_argument("--device", type=str, default="cuda:0") parser.add_argument("--dtype", type=str, default="bf16") + parser.add_argument("--cpu_offload", action="store_true") parser.add_argument( "--output_dir", @@ -179,6 +180,8 @@ def main(args): torch_dtype=args.dtype, ) pipeline.to(device=args.device) + if args.cpu_offload: + pipeline.enable_model_cpu_offload() video, audio = pipeline( prompt=args.prompt, diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 5aed290aa0cf..45cfc8e3cd99 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -194,7 +194,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
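With the offload sequence just below extended to include the audio VAE and vocoder, model CPU offload from user code stays a one-liner; a usage sketch (the checkpoint path is a placeholder, not a published repo id):

import torch

from diffusers import LTX2Pipeline

pipe = LTX2Pipeline.from_pretrained("path/to/ltx2-checkpoint", torch_dtype=torch.bfloat16)
# Components are moved to the accelerator only when used, following model_cpu_offload_seq.
pipe.enable_model_cpu_offload()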
""" - model_cpu_offload_seq = "text_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->transformer->vae->audio_vae->vocoder" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] From f9b947651f06c2658f56834d7eef833631551079 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 11:03:19 +0100 Subject: [PATCH 42/86] Fix pipeline audio VAE decoding dtype bug --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 45cfc8e3cd99..258e597d71f7 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -1030,6 +1030,7 @@ def __call__( video = self.video_processor.postprocess_video(video, output_type=output_type) audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + audio_latents = audio_latents.to(self.audio_vae.dtype) # NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's # decode method generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] From e89d9c1951b184ee58efda0f38b2d8c94976c895 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 23 Dec 2025 11:14:05 +0100 Subject: [PATCH 43/86] Fix video shape error in full pipeline test script --- scripts/ltx2_test_full_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/ltx2_test_full_pipeline.py b/scripts/ltx2_test_full_pipeline.py index 63c6f0400f6d..14b02a490fa7 100644 --- a/scripts/ltx2_test_full_pipeline.py +++ b/scripts/ltx2_test_full_pipeline.py @@ -200,8 +200,6 @@ def main(args): # Convert video to uint8 (but keep as NumPy array) video = (video * 255).round().astype("uint8") video = torch.from_numpy(video) - # video should already be frames first, reshape to channels-last (we want shape to be (*, F, H , W, C)) - video = video.permute(0, 1, 3, 4, 2) encode_video( video[0], From b5891b19b195b6f711e75dcc248798de17edf6c4 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 24 Dec 2025 06:07:38 +0100 Subject: [PATCH 44/86] Get LTX 2 T2V pipeline to produce reasonable outputs --- scripts/convert_ltx2_to_diffusers.py | 6 +++ .../models/transformers/transformer_ltx2.py | 44 +++++++++++++------ src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 15 +++++-- .../test_models_transformer_ltx2.py | 4 +- 4 files changed, 51 insertions(+), 18 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 6c4cac739632..479a569817c2 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -182,7 +182,10 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, "attention_bias": True, "attention_out_bias": True, "rope_theta": 10000.0, + "rope_double_precision": False, "causal_offset": 1, + "timestep_scale_multiplier": 1000, + "cross_attn_timestep_scale_multiplier": 1, }, } rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT @@ -222,7 +225,10 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, "attention_bias": True, "attention_out_bias": True, "rope_theta": 10000.0, + "rope_double_precision": True, "causal_offset": 1, + "timestep_scale_multiplier": 1000, + "cross_attn_timestep_scale_multiplier": 1, }, } rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT diff --git 
a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index ea9bca115e99..3d2d079608ba 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -18,6 +18,7 @@ from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union +import numpy as np import torch import torch.nn as nn @@ -561,6 +562,7 @@ def __init__( theta: float = 10000.0, causal_offset: int = 1, modality: str = "video", + double_precision: bool = True, ) -> None: super().__init__() @@ -586,6 +588,7 @@ def __init__( self.modality = modality if self.modality not in ["video", "audio"]: raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.") + self.double_precision = double_precision def prepare_video_coords( self, @@ -779,14 +782,26 @@ def forward( # 4. Create a 1D grid of frequencies for RoPE start = 1.0 end = self.theta - freqs = self.theta ** torch.linspace( - start=math.log(start, self.theta), - end=math.log(end, self.theta), - steps=self.dim // num_rope_elems, - device=device, - dtype=torch.float32, - ) - freqs = freqs * math.pi / 2.0 + if self.double_precision: + pow_indices = np.power( + self.theta, + np.linspace( + np.log(start) / np.log(self.theta), + np.log(end) / np.log(self.theta), + self.dim // num_rope_elems, + dtype=np.float64, + ), + ) + freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device) + else: + freqs = self.theta ** torch.linspace( + start=math.log(start, self.theta), + end=math.log(end, self.theta), + steps=self.dim // num_rope_elems, + device=device, + dtype=torch.float32, + ) + freqs = freqs * math.pi / 2.0 # 5. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape # (self.dim // num_elems,) @@ -885,7 +900,10 @@ def __init__( attention_bias: bool = True, attention_out_bias: bool = True, rope_theta: float = 10000.0, + rope_double_precision: bool = True, causal_offset: int = 1, + timestep_scale_multiplier: int = 1000, + cross_attn_timestep_scale_multiplier: int = 1, ) -> None: super().__init__() @@ -951,6 +969,7 @@ def __init__( theta=rope_theta, causal_offset=causal_offset, modality="video", + double_precision=rope_double_precision, ) self.audio_rope = LTX2AudioVideoRotaryPosEmbed( dim=audio_inner_dim, @@ -963,6 +982,7 @@ def __init__( theta=rope_theta, causal_offset=causal_offset, modality="audio", + double_precision=rope_double_precision, ) # Audio-to-Video, Video-to-Audio Cross-Attention @@ -977,6 +997,7 @@ def __init__( theta=rope_theta, causal_offset=causal_offset, modality="video", + double_precision=rope_double_precision, ) self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed( dim=audio_cross_attention_dim, @@ -988,6 +1009,7 @@ def __init__( theta=rope_theta, causal_offset=causal_offset, modality="audio", + double_precision=rope_double_precision, ) # 5. Transformer Blocks @@ -1038,8 +1060,6 @@ def forward( audio_num_frames: Optional[int] = None, video_coords: Optional[torch.Tensor] = None, audio_coords: Optional[torch.Tensor] = None, - timestep_scale_multiplier: int = 1000, - cross_attn_timestep_scale_multiplier: int = 1, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> torch.Tensor: @@ -1109,9 +1129,7 @@ def forward( audio_hidden_states = self.audio_proj_in(audio_hidden_states) # 3. 
Prepare timestep embeddings and modulation parameters - # Scale timestep - timestep = timestep * timestep_scale_multiplier - timestep_cross_attn_gate_scale_factor = cross_attn_timestep_scale_multiplier / timestep_scale_multiplier + timestep_cross_attn_gate_scale_factor = self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 258e597d71f7..a4ee5cb150a5 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -603,13 +603,13 @@ def prepare_audio_latents( generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if latents is not None: - return latents.to(device=device, dtype=dtype) - duration_s = num_frames / frame_rate latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) latent_length = int(duration_s * latents_per_second) + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_length + # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio @@ -915,6 +915,13 @@ def __call__( self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio, ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device, fps=frame_rate + ) # 7. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -946,6 +953,8 @@ def __call__( width=latent_width, fps=frame_rate, audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, # rope_interpolation_scale=rope_interpolation_scale, attention_kwargs=attention_kwargs, return_dict=False, diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index 6c0b97c58906..079273e975e5 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -58,7 +58,7 @@ def dummy_input(self): encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) - timestep = torch.rand((batch_size,)).to(torch_device) + timestep = torch.rand((batch_size,)).to(torch_device) * 1000 return { "hidden_states": hidden_states, @@ -121,7 +121,7 @@ def test_ltx2_consistency(self, seed=0, dtype=torch.float32): sampling_rate = 16000.0 hop_length = 160.0 - sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") + sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000 timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device) num_channels = 4 From 581f21c43120ce57020bd60fc1a86576adb0a4ef Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 29 Dec 2025 23:44:52 +0100 Subject: [PATCH 45/86] Make LTX 2.0 scheduler more consistent with original code --- scripts/convert_ltx2_to_diffusers.py | 4 +++- scripts/ltx2_test_full_pipeline.py | 13 ++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 479a569817c2..c7f066747032 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -1,4 +1,5 @@ import argparse +import math import os from contextlib import nullcontext from typing import Any, Dict, Optional, Tuple @@ -739,7 +740,8 @@ def main(args): if args.full_pipeline: scheduler = FlowMatchEulerDiscreteScheduler( - use_dynamic_shifting=True, + use_dynamic_shifting=False, + shift=math.exp(2.05), # Equivalent to dynamic shift if always using max_image_seq_len base_shift=0.95, max_shift=2.05, base_image_seq_len=1024, diff --git a/scripts/ltx2_test_full_pipeline.py b/scripts/ltx2_test_full_pipeline.py index 14b02a490fa7..6d7317ea0390 100644 --- a/scripts/ltx2_test_full_pipeline.py +++ b/scripts/ltx2_test_full_pipeline.py @@ -1,4 +1,5 @@ import argparse +import math import os from fractions import Fraction from typing import Optional @@ -6,7 +7,7 @@ import av # Needs to be installed separately (`pip install av`) import torch -from diffusers import LTX2Pipeline +from diffusers import LTX2Pipeline, FlowMatchEulerDiscreteScheduler # Video export functions copied from original LTX 2.0 code @@ -150,6 +151,7 @@ def parse_args(): parser.add_argument("--frame_rate", type=float, default=25.0) parser.add_argument("--guidance_scale", type=float, default=3.0) parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--apply_scheduler_fix", action="store_true") parser.add_argument("--device", type=str, default="cuda:0") parser.add_argument("--dtype", 
type=str, default="bf16") @@ -179,6 +181,15 @@ def main(args): revision=args.revision, torch_dtype=args.dtype, ) + if args.apply_scheduler_fix: + max_shift = pipeline.scheduler.config.max_shift + time_shift_type = pipeline.scheduler.config.time_shift_type + fixed_scheduler = FlowMatchEulerDiscreteScheduler.from_config( + pipeline.scheduler.config, + dynamic_shifting=False, + shift=math.exp(max_shift) if time_shift_type == "exponential" else max_shift, + ) + pipeline.scheduler = fixed_scheduler pipeline.to(device=args.device) if args.cpu_offload: pipeline.enable_model_cpu_offload() From e1f0b7e255922b4bc0c6f04e58fca1bef23ebba8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 30 Dec 2025 00:38:51 +0100 Subject: [PATCH 46/86] Fix typo when applying scheduler fix in T2V inference script --- scripts/ltx2_test_full_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/ltx2_test_full_pipeline.py b/scripts/ltx2_test_full_pipeline.py index 6d7317ea0390..37a649d5ea74 100644 --- a/scripts/ltx2_test_full_pipeline.py +++ b/scripts/ltx2_test_full_pipeline.py @@ -186,7 +186,7 @@ def main(args): time_shift_type = pipeline.scheduler.config.time_shift_type fixed_scheduler = FlowMatchEulerDiscreteScheduler.from_config( pipeline.scheduler.config, - dynamic_shifting=False, + use_dynamic_shifting=False, shift=math.exp(max_shift) if time_shift_type == "exponential" else max_shift, ) pipeline.scheduler = fixed_scheduler From 280e34781457a774652fab1c0400c431a73d6546 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 30 Dec 2025 08:05:56 +0530 Subject: [PATCH 47/86] Refactor Audio VAE to be simpler and remove helpers (#7) * remove resolve causality axes stuff. * remove a bunch of helpers. * remove adjust output shape helper. * remove the use of audiolatentshape. * move normalization and patchify out of pipeline. 
* fix * up * up * Remove unpatchify and patchify ops before audio latents denormalization (#9) --------- Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- scripts/convert_ltx2_to_diffusers.py | 5 +- scripts/test_ltx2_audio_conversion.py | 15 + .../autoencoders/autoencoder_kl_ltx2_audio.py | 362 ++++++++---------- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 13 +- 4 files changed, 179 insertions(+), 216 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 479a569817c2..d1384c1dcacd 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -70,7 +70,10 @@ "per_channel_statistics.std-of-means": "latents_std", } -LTX_2_0_AUDIO_VAE_RENAME_DICT = {} +LTX_2_0_AUDIO_VAE_RENAME_DICT = { + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} LTX_2_0_VOCODER_RENAME_DICT = { "ups": "upsamplers", diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py index 8d07a6f9b1fe..a6ba16ed9efa 100644 --- a/scripts/test_ltx2_audio_conversion.py +++ b/scripts/test_ltx2_audio_conversion.py @@ -22,6 +22,9 @@ def convert_state_dict(state_dict: dict) -> dict: if new_key.startswith("decoder."): new_key = new_key[len("decoder.") :] converted[f"decoder.{new_key}"] = value + + converted["latents_mean"] = converted.pop("decoder.per_channel_statistics.mean-of-means") + converted["latents_std"] = converted.pop("decoder.per_channel_statistics.std-of-means") return converted @@ -87,9 +90,21 @@ def main() -> None: latent_width, device=device, dtype=dtype, + generator=torch.Generator(device).manual_seed(42) ) original_out = original_decoder(dummy) + + from diffusers.pipelines.ltx2.pipeline_ltx2 import LTX2Pipeline + + _, a_channels, a_time, a_freq = dummy.shape + dummy = dummy.permute(0, 2, 1, 3).reshape(-1, a_time, a_channels * a_freq) + dummy = LTX2Pipeline._denormalize_audio_latents( + dummy, + diffusers_model.latents_mean, + diffusers_model.latents_std, + ) + dummy = dummy.view(-1, a_time, a_channels, a_freq).permute(0, 2, 1, 3) diffusers_out = diffusers_model.decode(dummy).sample torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 90ddf2aa6e6b..8cdcfa1a74c5 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import Optional, Set, Tuple, Union import torch @@ -27,57 +26,6 @@ LATENT_DOWNSAMPLE_FACTOR = 4 -SUPPORTED_CAUSAL_AXES = {"none", "width", "height", "width-compatibility"} - - -AudioLatentShape = namedtuple( - "AudioLatentShape", - [ - "batch", - "channels", - "frames", - "mel_bins", - ], -) - - -def _resolve_causality_axis(causality_axis: Optional[str] = None) -> Optional[str]: - normalized = "none" if causality_axis is None else str(causality_axis).lower() - if normalized not in SUPPORTED_CAUSAL_AXES: - raise NotImplementedError( - f"Unsupported causality_axis '{causality_axis}'. 
Supported: {sorted(SUPPORTED_CAUSAL_AXES)}" - ) - return None if normalized == "none" else normalized - - -def make_conv2d( - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: int = 1, - padding: Optional[Tuple[int, int, int, int]] = None, - dilation: int = 1, - groups: int = 1, - bias: bool = True, - causality_axis: Optional[str] = None, -) -> nn.Module: - if causality_axis is not None: - return LTX2AudioCausalConv2d( - in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis - ) - if padding is None: - padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size) - - return nn.Conv2d( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - ) class LTX2AudioCausalConv2d(nn.Module): @@ -147,14 +95,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x / rms -def build_normalization_layer(in_channels: int, *, num_groups: int = 32, normtype: str = "group") -> nn.Module: - if normtype == "group": - return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - if normtype == "pixel": - return LTX2AudioPixelNorm(dim=1, eps=1e-6) - raise ValueError(f"Invalid normalization type: {normtype}") - - class LTX2AudioAttnBlock(nn.Module): def __init__( self, @@ -164,7 +104,12 @@ def __init__( super().__init__() self.in_channels = in_channels - self.norm = build_normalization_layer(in_channels, normtype=norm_type) + if norm_type == "group": + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) @@ -211,23 +156,49 @@ def __init__( self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut - self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) + if norm_type == "group": + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") self.non_linearity = nn.SiLU() - self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if causality_axis is not None: + self.conv1 = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: self.temb_proj = nn.Linear(temb_channels, out_channels) - self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) + if norm_type == "group": + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") self.dropout = nn.Dropout(dropout) - self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if causality_axis is not None: + self.conv2 = LTX2AudioCausalConv2d( + out_channels, out_channels, 
kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = make_conv2d( - in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis - ) + if causality_axis is not None: + self.conv_shortcut = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: - self.nin_shortcut = make_conv2d( - in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis - ) + if causality_axis is not None: + self.nin_shortcut = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: h = self.norm1(x) @@ -254,7 +225,12 @@ def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[s self.with_conv = with_conv self.causality_axis = causality_axis if self.with_conv: - self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if causality_axis is not None: + self.conv = LTX2AudioCausalConv2d( + in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x: torch.Tensor) -> torch.Tensor: x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") @@ -273,26 +249,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x - -class LTX2AudioPerChannelStatistics(nn.Module): - """ - Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over - the entire dataset and stored in model's checkpoint under AudioVAE state_dict - """ - - def __init__(self, latent_channels: int = 128) -> None: - super().__init__() - # Sayak notes: `empty` always causes problems in CI. Should we consider using `torch.ones`? - self.register_buffer("std-of-means", torch.empty(latent_channels)) - self.register_buffer("mean-of-means", torch.empty(latent_channels)) - - def denormalize(self, x: torch.Tensor) -> torch.Tensor: - return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) - - def normalize(self, x: torch.Tensor) -> torch.Tensor: - return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) - - class LTX2AudioAudioPatchifier: """ Patchifier for spectrogram/audio latents. 
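The patchifier kept by this refactor is a pure layout transform between (batch, channels, time, mel_bins) and (batch, time, channels * mel_bins); with 8 latent channels and 16 latent mel bins the flattened feature size is 128, which lines up with the latents_mean / latents_std buffers applied in the conversion test above. A standalone round-trip sketch with toy shapes:

import torch

batch, channels, time, mel_bins = 1, 8, 121, 16
latents = torch.randn(batch, channels, time, mel_bins)

# patchify: (B, C, T, F) -> (B, T, C * F); unpatchify is the exact inverse.
patched = latents.permute(0, 2, 1, 3).reshape(batch, time, channels * mel_bins)  # (1, 121, 128)
restored = patched.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)     # (1, 8, 121, 16)
assert torch.equal(latents, restored)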
@@ -316,11 +272,9 @@ def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor: batch, channels, time, freq = audio_latents.shape return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq) - def unpatchify(self, audio_latents: torch.Tensor, output_shape: AudioLatentShape) -> torch.Tensor: + def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor: batch, time, _ = audio_latents.shape - channels = output_shape.channels - freq = output_shape.mel_bins - return audio_latents.view(batch, time, channels, freq).permute(0, 2, 1, 3) + return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3) @property def patch_size(self) -> Tuple[int, int, int]: @@ -356,9 +310,6 @@ def __init__( ) -> None: super().__init__() - resolved_causality_axis = _resolve_causality_axis(causality_axis) - - self.per_channel_statistics = LTX2AudioPerChannelStatistics(latent_channels=base_channels) self.sample_rate = sample_rate self.mel_hop_length = mel_hop_length self.is_causal = is_causal @@ -384,116 +335,43 @@ def __init__( self.latent_channels = latent_channels self.channel_multipliers = ch_mult self.attn_resolutions = attn_resolutions - self.causality_axis = resolved_causality_axis + self.causality_axis = causality_axis base_block_channels = base_channels * self.channel_multipliers[-1] base_resolution = resolution // (2 ** (self.num_resolutions - 1)) self.z_shape = (1, latent_channels, base_resolution, base_resolution) - self.conv_in = make_conv2d( - latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis - ) - self.non_linearity = nn.SiLU() - self.mid = self._build_mid_layers(base_block_channels, dropout, mid_block_add_attention) - self.up, final_block_channels = self._build_up_path( - initial_block_channels=base_block_channels, - dropout=dropout, - resamp_with_conv=True, - ) - - self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) - self.conv_out = make_conv2d( - final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis - ) - - def _adjust_output_shape(self, decoded_output: torch.Tensor, target_shape: AudioLatentShape) -> torch.Tensor: - _, _, current_time, current_freq = decoded_output.shape - target_channels = target_shape.channels - target_time = target_shape.frames - target_freq = target_shape.mel_bins - - decoded_output = decoded_output[ - :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) - ] - - time_padding_needed = target_time - decoded_output.shape[2] - freq_padding_needed = target_freq - decoded_output.shape[3] - - if time_padding_needed > 0 or freq_padding_needed > 0: - padding = ( - 0, - max(freq_padding_needed, 0), - 0, - max(time_padding_needed, 0), - ) - decoded_output = F.pad(decoded_output, padding) - - decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] - - return decoded_output - - def forward( - self, - sample: torch.Tensor, - ) -> torch.Tensor: - latent_shape = AudioLatentShape( - batch=sample.shape[0], - channels=sample.shape[1], - frames=sample.shape[2], - mel_bins=sample.shape[3], - ) - - sample_patched = self.patchifier.patchify(sample) - sample_denormalized = self.per_channel_statistics.denormalize(sample_patched) - sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) - - target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR - if self.causality_axis is not None: - target_frames = max(target_frames - 
(LATENT_DOWNSAMPLE_FACTOR - 1), 1) - - target_shape = AudioLatentShape( - batch=latent_shape.batch, - channels=self.out_ch, - frames=target_frames, - mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, - ) - - hidden_features = self.conv_in(sample) - hidden_features = self._run_mid_layers(hidden_features) - hidden_features = self._run_upsampling_path(hidden_features) - decoded_output = self._finalize_output(hidden_features) - - decoded_output = self._adjust_output_shape(decoded_output, target_shape) - - return decoded_output - - def _build_mid_layers(self, channels: int, dropout: float, add_attention: bool) -> nn.Module: - mid = nn.Module() - mid.block_1 = LTX2AudioResnetBlock( - in_channels=channels, - out_channels=channels, + self.conv_in = LTX2AudioCausalConv2d( + latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1) + self.non_linearity = nn.SiLU() + self.mid = nn.Module() + self.mid.block_1 = LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, temb_channels=self.temb_ch, dropout=dropout, norm_type=self.norm_type, causality_axis=self.causality_axis, ) - mid.attn_1 = LTX2AudioAttnBlock(channels, norm_type=self.norm_type) if add_attention else nn.Identity() - mid.block_2 = LTX2AudioResnetBlock( - in_channels=channels, - out_channels=channels, + if mid_block_add_attention: + self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, temb_channels=self.temb_ch, dropout=dropout, norm_type=self.norm_type, causality_axis=self.causality_axis, ) - return mid - def _build_up_path( - self, initial_block_channels: int, dropout: float, resamp_with_conv: bool - ) -> tuple[nn.ModuleList, int]: - up_modules = nn.ModuleList() - block_in = initial_block_channels + self.up = nn.ModuleList() + block_in = base_block_channels curr_res = self.resolution // (2 ** (self.num_resolutions - 1)) for level in reversed(range(self.num_resolutions)): @@ -519,39 +397,89 @@ def _build_up_path( stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) if level != 0: - stage.upsample = LTX2AudioUpsample(block_in, resamp_with_conv, causality_axis=self.causality_axis) + stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis) curr_res *= 2 - up_modules.insert(0, stage) + self.up.insert(0, stage) + + final_block_channels = block_in + + if self.norm_type == "group": + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True + ) + elif self.norm_type == "pixel": + self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {self.norm_type}") + + if self.causality_axis is not None: + self.conv_out = LTX2AudioCausalConv2d( + final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1) - return up_modules, block_in + def forward( + self, + sample: torch.Tensor, + ) -> torch.Tensor: + _, _, frames, mel_bins = sample.shape - def _run_mid_layers(self, features: torch.Tensor) -> torch.Tensor: - features = self.mid.block_1(features, temb=None) - 
features = self.mid.attn_1(features) - return self.mid.block_2(features, temb=None) + target_frames = frames * LATENT_DOWNSAMPLE_FACTOR + + if self.causality_axis is not None: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_channels = self.out_ch + target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins + + hidden_features = self.conv_in(sample) + hidden_features = self.mid.block_1(hidden_features, temb=None) + hidden_features = self.mid.attn_1(hidden_features) + hidden_features = self.mid.block_2(hidden_features, temb=None) - def _run_upsampling_path(self, features: torch.Tensor) -> torch.Tensor: for level in reversed(range(self.num_resolutions)): stage = self.up[level] for block_idx, block in enumerate(stage.block): - features = block(features, temb=None) + hidden_features = block(hidden_features, temb=None) if stage.attn: - features = stage.attn[block_idx](features) + hidden_features = stage.attn[block_idx](hidden_features) if level != 0 and hasattr(stage, "upsample"): - features = stage.upsample(features) + hidden_features = stage.upsample(hidden_features) - return features - - def _finalize_output(self, features: torch.Tensor) -> torch.Tensor: if self.give_pre_end: - return features + return hidden_features - hidden = self.norm_out(features) + hidden = self.norm_out(hidden_features) hidden = self.non_linearity(hidden) - decoded = self.conv_out(hidden) - return torch.tanh(decoded) if self.tanh_out else decoded + decoded_output = self.conv_out(hidden) + decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output + + _, _, current_time, current_freq = decoded_output.shape + target_time = target_frames + target_freq = target_mel_bins + + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + if time_padding_needed > 0 or freq_padding_needed > 0: + padding = ( + 0, + max(freq_padding_needed, 0), + 0, + max(time_padding_needed, 0), + ) + decoded_output = F.pad(decoded_output, padding) + + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): @@ -583,7 +511,10 @@ def __init__( ) -> None: super().__init__() - resolved_causality_axis = _resolve_causality_axis(causality_axis) + supported_causality_axes = {"none", "width", "height", "width-compatibility"} + if causality_axis not in supported_causality_axes: + raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}") + attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions self.decoder = LTX2AudioDecoder( @@ -596,7 +527,7 @@ def __init__( resolution=resolution, latent_channels=latent_channels, norm_type=norm_type, - causality_axis=resolved_causality_axis, + causality_axis=causality_axis, dropout=dropout, mid_block_add_attention=mid_block_add_attention, sample_rate=sample_rate, @@ -605,6 +536,13 @@ def __init__( mel_bins=mel_bins, ) + # Per-channel statistics for normalizing and denormalizing the latent representation. 
These statistics are computed over
+        # the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
+        latents_mean = torch.zeros((base_channels,))
+        latents_std = torch.ones((base_channels,))
+        self.register_buffer("latents_mean", latents_mean, persistent=True)
+        self.register_buffer("latents_std", latents_std, persistent=True)
+
         # TODO: calculate programmatically instead of hardcoding
         self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR  # 4
         # TODO: confirm whether the mel compression ratio below is correct
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
index a4ee5cb150a5..fe31d02ec00f 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
@@ -516,6 +516,12 @@ def _denormalize_latents(
         latents = latents * latents_std / scaling_factor + latents_mean
         return latents
 
+    @staticmethod
+    def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
+        latents_mean = latents_mean.to(latents.device, latents.dtype)
+        latents_std = latents_std.to(latents.device, latents.dtype)
+        return (latents * latents_std) + latents_mean
+
     @staticmethod
     def _pack_audio_latents(
         latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
@@ -1038,10 +1044,11 @@ def __call__(
             video = self.vae.decode(latents, timestep, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
 
-            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
             audio_latents = audio_latents.to(self.audio_vae.dtype)
-            # NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's
-            # decode method
+            audio_latents = self._denormalize_audio_latents(
+                audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
+            )
+            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
             generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
             audio = self.vocoder(generated_mel_spectrograms)

From 46822c43dbe9dde816f467c0ca4aa6fb126f5998 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Tue, 30 Dec 2025 09:06:07 +0530
Subject: [PATCH 48/86] Add support for I2V (#8)

* start i2v.

* up

* up

* up

* up

* up

* remove uniform strategy code.

* remove unneeded code.
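A minimal sketch of the first-frame conditioning idea used by the I2V pipeline (illustrative shapes and made-up variable names such as `image_latents`, `noise`, and `mask`; not the actual pipeline code):

```py
import torch

# Toy latent video: 1 sample, 128 channels, 8 latent frames, a 16 x 22 latent grid (illustrative sizes).
image_latents = torch.randn(1, 128, 1, 16, 22).repeat(1, 1, 8, 1, 1)  # encoded image tiled over time
noise = torch.randn(1, 128, 8, 16, 22)

# The conditioning mask is 1 only on the first latent frame and 0 everywhere else.
mask = torch.zeros(1, 1, 8, 16, 22)
mask[:, :, 0] = 1.0

# Frame 0 keeps the clean image latents; all later frames start from pure noise.
latents = image_latents * mask + noise * (1 - mask)

# Conditioned positions are given timestep 0, i.e. they are treated as already denoised.
t = torch.tensor(700.0)
video_timestep = t * (1 - mask)
```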
--- src/diffusers/__init__.py | 2 + .../models/transformers/transformer_ltx2.py | 13 +- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/ltx2/__init__.py | 2 + src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 10 - .../ltx2/pipeline_ltx2_image2video.py | 1138 +++++++++++++++++ 6 files changed, 1152 insertions(+), 17 deletions(-) create mode 100644 src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ea429c2e4115..2e99ea8063a4 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -538,6 +538,7 @@ "LTXLatentUpsamplePipeline", "LTXPipeline", "LTX2Pipeline", + "LTX2ImageToVideoPipeline", "LucyEditPipeline", "Lumina2Pipeline", "Lumina2Text2ImgPipeline", @@ -1245,6 +1246,7 @@ LTXLatentUpsamplePipeline, LTXPipeline, LTX2Pipeline, + LTX2ImageToVideoPipeline, LucyEditPipeline, Lumina2Pipeline, Lumina2Text2ImgPipeline, diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 3d2d079608ba..1f685fdc3a81 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -1051,6 +1051,7 @@ def forward( encoder_hidden_states: torch.Tensor, audio_encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, + audio_timestep: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, audio_encoder_attention_mask: Optional[torch.Tensor] = None, num_frames: Optional[int] = None, @@ -1073,8 +1074,7 @@ def forward( Input patchified audio latents of shape (batch_size, num_audio_tokens, audio_in_channels). encoder_hidden_states (`torch.Tensor`): Input text embeddings of shape TODO. - timesteps (`torch.Tensor`): - Timestep information of shape (batch_size, num_train_timesteps). + TODO for the rest. Returns: `AudioVisualModelOutput` or `tuple`: @@ -1097,6 +1097,9 @@ def forward( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) + # Determine timestep for audio. 
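+        # If `audio_timestep` is not provided, the audio branch simply reuses the video timestep. The I2V pipeline
+        # passes a per-token video timestep (zeroed at image-conditioned positions) together with the plain scalar
+        # timestep for the audio latents.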
+ audio_timestep = audio_timestep if audio_timestep is not None else timestep + # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 @@ -1143,7 +1146,7 @@ def forward( embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) temb_audio, audio_embedded_timestep = self.audio_time_embed( - timestep.flatten(), + audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) @@ -1165,12 +1168,12 @@ def forward( video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( - timestep.flatten(), + audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( - timestep.flatten() * timestep_cross_attn_gate_scale_factor, + audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ef9430043bed..eaf444d5ec37 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -288,7 +288,7 @@ "LTXConditionPipeline", "LTXLatentUpsamplePipeline", ] - _import_structure["ltx2"] = ["LTX2Pipeline"] + _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -720,7 +720,7 @@ LEditsPPPipelineStableDiffusionXL, ) from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline - from .ltx2 import LTX2Pipeline + from .ltx2 import LTX2Pipeline, LTX2ImageToVideoPipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index d23123089fb8..a97c836e0c7d 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] + _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] _import_structure["text_encoder"] = ["LTX2AudioVisualTextEncoder"] _import_structure["vocoder"] = ["LTX2Vocoder"] @@ -35,6 +36,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_ltx2 import LTX2Pipeline + from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline from .text_encoder import LTX2AudioVisualTextEncoder from .vocoder import LTX2Vocoder diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index fe31d02ec00f..2617e5cacb64 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -496,16 +496,6 @@ def _unpack_latents( latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) return latents - @staticmethod - def _normalize_latents( - 
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 - ) -> torch.Tensor: - # Normalize latents across the channel dimension [B, C, F, H, W] - latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents = (latents - latents_mean) * scaling_factor / latents_std - return latents - @staticmethod def _denormalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py new file mode 100644 index 000000000000..9f0755bb3144 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -0,0 +1,1138 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Any, Callable, Dict, List, Optional, Union +import inspect +import numpy as np +import torch + +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import LTX2PipelineOutput +from ..pipeline_utils import DiffusionPipeline +from .text_encoder import LTX2AudioVisualTextEncoder +from .vocoder import LTX2Vocoder +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from transformers import GemmaTokenizer, GemmaTokenizerFast +from ...video_processor import VideoProcessor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2ImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + ... ) + >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> video = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=704, + ... height=480, + ... num_frames=121, + ... num_inference_steps=40, + ... 
).frames[0] + >>> export_to_video(video, "output.mp4", fps=24) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Pipeline for image-to-video generation. 
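+    The input image is encoded into the first latent frame, which is kept clean through a conditioning mask while
+    the remaining video latents and a separate stream of audio latents are denoised.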
+
+    Reference: https://github.com/Lightricks/LTX-Video
+
+    TODO
+    """
+
+    model_cpu_offload_seq = "text_encoder->transformer->vae->audio_vae->vocoder"
+    _optional_components = []
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+    def __init__(
+        self,
+        scheduler: FlowMatchEulerDiscreteScheduler,
+        vae: AutoencoderKLLTX2Video,
+        audio_vae: AutoencoderKLLTX2Audio,
+        text_encoder: LTX2AudioVisualTextEncoder,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+        transformer: LTX2VideoTransformer3DModel,
+        vocoder: LTX2Vocoder,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            vae=vae,
+            audio_vae=audio_vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            transformer=transformer,
+            vocoder=vocoder,
+            scheduler=scheduler,
+        )
+
+        self.vae_spatial_compression_ratio = (
+            self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+        )
+        self.vae_temporal_compression_ratio = (
+            self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+        )
+        # TODO: check whether the MEL compression ratio logic here is correct
+        self.audio_vae_mel_compression_ratio = (
+            self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
+        )
+        self.audio_vae_temporal_compression_ratio = (
+            self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4
+        )
+        self.transformer_spatial_patch_size = (
+            self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
+        )
+        self.transformer_temporal_patch_size = (
+            self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
+        )
+
+        self.audio_sampling_rate = (
+            self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000
+        )
+        self.audio_hop_length = (
+            self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160
+        )
+
+        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear")
+        self.tokenizer_max_length = (
+            self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
+        )
+
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
+    def _get_gemma_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        num_videos_per_prompt: int = 1,
+        max_sequence_length: int = 1024,
+        scale_factor: int = 8,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`str` or `torch.device`):
+                torch device to place the resulting embeddings on
+            dtype: (`torch.dtype`):
+                torch dtype to cast the prompt embeds to
+            max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt.
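+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos to generate per prompt; the returned embeddings are duplicated accordingly.
+            scale_factor (`int`, *optional*, defaults to 8):
+                Scale factor forwarded to the text encoder.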
+        """
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.base_text_encoder.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if getattr(self, "tokenizer", None) is not None:
+            # Gemma expects left padding for chat-style prompts
+            self.tokenizer.padding_side = "left"
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        prompt = [p.strip() for p in prompt]
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        prompt_attention_mask = text_inputs.attention_mask
+
+        prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self.text_encoder(
+            text_input_ids.to(device),
+            attention_mask=prompt_attention_mask.to(device),
+            padding_side=self.tokenizer.padding_side,
+            scale_factor=scale_factor,
+        )
+        prompt_embeds = prompt_embeds.to(dtype=dtype)
+        audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype)
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        _, audio_seq_len, _ = audio_prompt_embeds.shape
+        audio_prompt_embeds = audio_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        audio_prompt_embeds = audio_prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1)
+
+        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+        return prompt_embeds, audio_prompt_embeds, prompt_attention_mask
+
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        do_classifier_free_guidance: bool = True,
+        num_videos_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        audio_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 1024,
+        scale_factor: int = 8,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+                Whether to use classifier free guidance or not.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
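+            audio_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings for the audio branch. If not provided, they are generated from the
+                `prompt` input argument together with `prompt_embeds`.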
+ negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
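+        # Illustrative shapes (patch_size = patch_size_t = 1): [1, 8 * 16 * 22, 128] -> [1, 128, 8, 16, 22].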
+        batch_size = latents.size(0)
+        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+        return latents
+
+    @staticmethod
+    def _normalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        # Normalize latents across the channel dimension [B, C, F, H, W]
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = (latents - latents_mean) * scaling_factor / latents_std
+        return latents
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
+    def _denormalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        # Denormalize latents across the channel dimension [B, C, F, H, W]
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = latents * latents_std / scaling_factor + latents_mean
+        return latents
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents
+    def _pack_audio_latents(
+        latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
+    ) -> torch.Tensor:
+        # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins
+        if patch_size is not None and patch_size_t is not None:
+            # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (an ndim=3 tensor).
+            # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size.
+            batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
+            post_patch_latent_length = latent_length // patch_size_t
+            post_patch_mel_bins = latent_mel_bins // patch_size
+            latents = latents.reshape(
+                batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
+            )
+            latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+        else:
+            # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel)
+            # patch_size of M (all mel bins constitute a single patch) and a patch_size_t of 1.
+            latents = latents.transpose(1, 2).flatten(2, 3)  # [B, C, L, M] --> [B, L, C * M]
+        return latents
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents
+    def _unpack_audio_latents(
+        latents: torch.Tensor,
+        latent_length: int,
+        num_mel_bins: int,
+        patch_size: Optional[int] = None,
+        patch_size_t: Optional[int] = None,
+    ) -> torch.Tensor:
+        # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M],
+        # where L is the latent audio length and M is the number of mel bins.
+        if patch_size is not None and patch_size_t is not None:
+            batch_size = latents.size(0)
+            latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
+            latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
+        else:
+            # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1.
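+            # Illustrative shapes: [1, 250, 8 * 16] -> unflatten -> [1, 250, 8, 16] -> transpose -> [1, 8, 250, 16].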
+ latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + def prepare_latents( + self, + image: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + mask_shape = (batch_size, 1, num_frames, height, width) + + if latents is not None: + conditioning_mask = latents.new_zeros(mask_shape) + conditioning_mask[:, :, 0] = 1.0 + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." + ) + return latents.to(device=device, dtype=dtype), conditioning_mask + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i], "argmax") + for i in range(batch_size) + ] + else: + init_latents = [ + retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator, "argmax") for img in image + ] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) + + # First condition is image latents and those should be kept clean. + conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) + conditioning_mask[:, :, 0] = 1.0 + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # Interpolation. 
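+        # Frame 0 keeps the clean image latents (mask == 1); all later frames start from pure noise (mask == 0).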
+ latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) + + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + return latents, conditioning_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + num_mel_bins: int = 64, + num_frames: int = 121, + frame_rate: float = 25.0, + sampling_rate: int = 16000, + hop_length: int = 160, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + duration_s = num_frames / frame_rate + latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + latent_length = int(duration_s * latents_per_second) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_length + + # TODO: confirm whether this logic is correct + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents, latent_length + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 25.0, + num_inference_steps: int = 40, + timesteps: List[int] = None, + guidance_scale: float = 3.0, + guidance_rescale: float = 0.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + audio_latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + audio_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_audio_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + output_type: Optional[str] = "pil", + return_dict: 
bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate + frame_rate (`float`, *optional*, defaults to `25.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `3.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. 
+ prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + audio_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings for audio processing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings for audio processing. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + audio_prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_audio_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + audio_prompt_embeds=audio_prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + negative_audio_prompt_embeds=negative_audio_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + audio_prompt_embeds = torch.cat([negative_audio_prompt_embeds, audio_prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare latent variables + if latents is None: + image = self.video_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=prompt_embeds.dtype) + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents, audio_num_frames = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + num_mel_bins=num_mel_bins, + num_frames=num_frames, # Video frames, audio frames will be calculated from this + frame_rate=frame_rate, + sampling_rate=self.audio_sampling_rate, + hop_length=self.audio_hop_length, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. 
Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device, fps=frame_rate + ) + + # 7. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep = t.expand(latent_model_input.shape[0]) + video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + audio_encoder_attention_mask=prompt_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (noise_pred_video_text - noise_pred_video_uncond) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) + + if self.guidance_rescale > 0: + # Based on 3.4. 
in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred_video = self._unpack_latents( + noise_pred_video, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + noise_pred_video = noise_pred_video[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + audio_latents = audio_latents.to(self.audio_vae.dtype) + # NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's + # decode method + 
generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) From bd607b97a8c23f17a6b13e98f0b6b93576e8c409 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Tue, 30 Dec 2025 19:53:35 -0800 Subject: [PATCH 49/86] Denormalize audio latents in I2V pipeline (analogous to T2V change) (#11) --- .../pipelines/ltx2/pipeline_ltx2_image2video.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 9f0755bb3144..359e665d4b28 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -565,7 +565,14 @@ def _unpack_audio_latents( # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1. latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) return latents - + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + def prepare_latents( self, image: Optional[torch.Tensor] = None, @@ -1122,10 +1129,11 @@ def __call__( video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) - audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) audio_latents = audio_latents.to(self.audio_vae.dtype) - # NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's - # decode method + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] audio = self.vocoder(generated_mel_spectrograms) From d3f10fe54e8a88ac941e1f547c9cac7b17c59b29 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 31 Dec 2025 09:36:48 +0530 Subject: [PATCH 50/86] test i2v. --- scripts/ltx2_test_full_pipeline_i2v.py | 206 +++++++++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 scripts/ltx2_test_full_pipeline_i2v.py diff --git a/scripts/ltx2_test_full_pipeline_i2v.py b/scripts/ltx2_test_full_pipeline_i2v.py new file mode 100644 index 000000000000..01b18e5eb814 --- /dev/null +++ b/scripts/ltx2_test_full_pipeline_i2v.py @@ -0,0 +1,206 @@ + +import argparse +import os +from fractions import Fraction +from typing import Optional +from PIL import Image + +import av # Needs to be installed separately (`pip install av`) +import torch + +from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline + + +# Video export functions copied from original LTX 2.0 code +def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: + """ + Prepare the audio stream for writing. 
+ """ + audio_stream = container.add_stream("aac", rate=audio_sample_rate) + audio_stream.codec_context.sample_rate = audio_sample_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) + return audio_stream + + +def _resample_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame +) -> None: + cc = audio_stream.codec_context + + # Use the encoder's format/layout/rate as the *target* + target_format = cc.format or "fltp" # AAC → usually fltp + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or frame_in.sample_rate + + audio_resampler = av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in audio_resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + # flush audio encoder + for packet in audio_stream.encode(): + container.mux(packet) + + +def _write_audio( + container: av.container.Container, + audio_stream: av.audio.AudioStream, + samples: torch.Tensor, + audio_sample_rate: int, +) -> None: + if samples.ndim == 1: + samples = samples[:, None] + + if samples.shape[1] != 2 and samples.shape[0] == 2: + samples = samples.T + + if samples.shape[1] != 2: + raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.") + + # Convert to int16 packed for ingestion; resampler converts to encoder fmt. + if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + _resample_audio(container, audio_stream, frame_in) + + +def encode_video( + video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str +) -> None: + video_np = video.cpu().numpy() + + _, height, width, _ = video_np.shape + + container = av.open(output_path, mode="w") + stream = container.add_stream("libx264", rate=int(fps)) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + if audio is not None: + if audio_sample_rate is None: + raise ValueError("audio_sample_rate is required when audio is provided") + + audio_stream = _prepare_audio_stream(container, audio_sample_rate) + + for frame_array in video_np: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + + # Flush encoder + for packet in stream.encode(): + container.mux(packet) + + if audio is not None: + _write_audio(container, audio_stream, audio, audio_sample_rate) + + container.close() + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model") + parser.add_argument("--revision", type=str, default="main") + + parser.add_argument("--image_path", required=True, type=str) + parser.add_argument( + "--prompt", + type=str, + default="An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. 
The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot.", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." + ) + + parser.add_argument("--num_inference_steps", type=int, default=40) + parser.add_argument("--height", type=int, default=512) + parser.add_argument("--width", type=int, default=768) + parser.add_argument("--num_frames", type=int, default=121) + parser.add_argument("--frame_rate", type=float, default=25.0) + parser.add_argument("--guidance_scale", type=float, default=3.0) + parser.add_argument("--seed", type=int, default=42) + + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--dtype", type=str, default="bf16") + parser.add_argument("--cpu_offload", action="store_true") + + parser.add_argument( + "--output_dir", + type=str, + default="samples", + help="Output directory for generated video", + ) + parser.add_argument( + "--output_filename", + type=str, + default="ltx2_sample_video.mp4", + help="Filename of the exported generated video", + ) + + args = parser.parse_args() + args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 + return args + + +def main(args): + pipeline = LTX2ImageToVideoPipeline.from_pretrained( + args.model_id, revision=args.revision, torch_dtype=args.dtype, + ) + if args.cpu_offload: + pipeline.enable_model_cpu_offload() + else: + pipeline.to(device=args.device) + + video, audio = pipeline( + image=Image.open(args.image_path), + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + generator=torch.Generator(device=args.device).manual_seed(args.seed), + output_type="np", + return_dict=False, + ) + + # Convert video to uint8 (but keep as NumPy array) + video = (video * 255).round().astype("uint8") + video = torch.from_numpy(video) + + encode_video( + video[0], + fps=args.frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipeline.vocoder.config.output_sampling_rate, # should be 24000 + output_path=os.path.join(args.output_dir, args.output_filename), + ) + + +if __name__ == '__main__': + args = parse_args() + main(args) From caae16768a240def1a366d8173d1c4e825bfc5c8 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Mon, 5 Jan 2026 06:41:13 -0800 Subject: [PATCH 51/86] Move Video and Audio Text Encoder Connectors to Transformer (#12) * Denormalize audio latents in I2V pipeline (analogous to T2V change) * Initial refactor to put video and audio text encoder connectors in transformer * Get LTX 2 transformer tests working after connector refactor * precompute run_connectors,. 
* fixes * Address review comments * Calculate RoPE double precisions freqs using torch instead of np * Further simplify LTX 2 RoPE freq calc * Make connectors a separate module (#18) * remove text_encoder.py * address yiyi's comments. * up * up * up * up --------- Co-authored-by: sayakpaul --- scripts/convert_ltx2_to_diffusers.py | 232 ++++--- .../models/transformers/transformer_ltx2.py | 30 +- src/diffusers/pipelines/ltx2/__init__.py | 4 +- src/diffusers/pipelines/ltx2/connectors.py | 281 ++++++++ src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 136 +++- .../ltx2/pipeline_ltx2_image2video.py | 137 +++- src/diffusers/pipelines/ltx2/text_encoder.py | 625 ------------------ .../test_models_transformer_ltx2.py | 1 + 8 files changed, 630 insertions(+), 816 deletions(-) create mode 100644 src/diffusers/pipelines/ltx2/connectors.py delete mode 100644 src/diffusers/pipelines/ltx2/text_encoder.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index eb311c3bc0ce..9f58d8f344ce 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -8,18 +8,11 @@ import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from transformers import AutoModel, AutoTokenizer - -from diffusers import ( - AutoencoderKLLTX2Audio, - AutoencoderKLLTX2Video, - FlowMatchEulerDiscreteScheduler, - LTX2Pipeline, - LTX2VideoTransformer3DModel, -) +from transformers import AutoModel, AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler, LTX2Pipeline, LTX2VideoTransformer3DModel +from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder from diffusers.utils.import_utils import is_accelerate_available -from diffusers.pipelines.ltx2.text_encoder import LTX2AudioVisualTextEncoder -from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder CTX = init_empty_weights if is_accelerate_available() else nullcontext @@ -134,6 +127,17 @@ def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str "adaln_single": convert_ltx2_transformer_adaln_single, } +LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = { + "connectors.": "", + "video_embeddings_connector": "video_connector", + "audio_embeddings_connector": "audio_connector", + "transformer_1d_blocks": "transformer_blocks", + "text_embedding_projection.aggregate_embed": "text_proj_in", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.channel": remove_keys_inplace, "per_channel_statistics.mean-of-stds": remove_keys_inplace, @@ -146,7 +150,27 @@ def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} -LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP = {} + +def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + connector_prefixes = ( + "video_embeddings_connector", + "audio_embeddings_connector", + "transformer_1d_blocks", + "text_embedding_projection.aggregate_embed", + "connectors.", + "video_connector", + "audio_connector", + "text_proj_in", + ) + + transformer_state_dict, connector_state_dict = {}, {} + for key, value in state_dict.items(): + if key.startswith(connector_prefixes): + connector_state_dict[key] = value + else: + transformer_state_dict[key] = value + + return transformer_state_dict, connector_state_dict def get_ltx2_transformer_config(version: str) 
-> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: @@ -240,32 +264,109 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, return config, rename_dict, special_keys_remap +def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "caption_channels": 16, + "text_proj_in_factor": 3, + "video_connector_num_attention_heads": 4, + "video_connector_attention_head_dim": 8, + "video_connector_num_layers": 1, + "video_connector_num_learnable_registers": None, + "audio_connector_num_attention_heads": 4, + "audio_connector_attention_head_dim": 8, + "audio_connector_num_layers": 1, + "audio_connector_num_learnable_registers": None, + "connector_rope_base_seq_len": 32, + "rope_theta": 10000.0, + "rope_double_precision": False, + "causal_temporal_positioning": False, + }, + } + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "caption_channels": 3840, + "text_proj_in_factor": 49, + "video_connector_num_attention_heads": 30, + "video_connector_attention_head_dim": 128, + "video_connector_num_layers": 2, + "video_connector_num_learnable_registers": 128, + "audio_connector_num_attention_heads": 30, + "audio_connector_attention_head_dim": 128, + "audio_connector_num_layers": 2, + "audio_connector_num_learnable_registers": 128, + "connector_rope_base_seq_len": 4096, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_temporal_positioning": False, + }, + } + + rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT + special_keys_remap = {} + + return config, rename_dict, special_keys_remap + + def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version) diffusers_config = config["diffusers_config"] + transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict) + with init_empty_weights(): transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config) # Handle official code --> diffusers key remapping via the remap dict - for key in list(original_state_dict.keys()): + for key in list(transformer_state_dict.keys()): new_key = key[:] for replace_key, rename_key in rename_dict.items(): new_key = new_key.replace(replace_key, rename_key) - update_state_dict_inplace(original_state_dict, key, new_key) + update_state_dict_inplace(transformer_state_dict, key, new_key) # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in # special_keys_remap - for key in list(original_state_dict.keys()): + for key in list(transformer_state_dict.keys()): for special_key, handler_fn_inplace in special_keys_remap.items(): if special_key not in key: continue - handler_fn_inplace(key, original_state_dict) + handler_fn_inplace(key, transformer_state_dict) - transformer.load_state_dict(original_state_dict, strict=True, assign=True) + transformer.load_state_dict(transformer_state_dict, strict=True, assign=True) return transformer +def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors: + config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version) + diffusers_config = config["diffusers_config"] + + _, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict) + if 
len(connector_state_dict) == 0: + raise ValueError("No connector weights found in the provided state dict.") + + with init_empty_weights(): + connectors = LTX2TextConnectors.from_config(diffusers_config) + + for key in list(connector_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(connector_state_dict, key, new_key) + + for key in list(connector_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, connector_state_dict) + + connectors.load_state_dict(connector_state_dict, strict=True, assign=True) + return connectors + + def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": config = { @@ -471,81 +572,6 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D return vocoder -def get_ltx2_text_encoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: - if version == "2.0": - config = { - "model_id": "diffusers-internal-dev/new-ltx-model", - "diffusers_config": { - "text_encoder_hidden_dim": 3840, - "text_proj_in_factor": 49, - "video_connector_num_attention_heads": 30, - "video_connector_attention_head_dim": 128, - "video_connector_num_layers": 2, - "video_connector_num_learnable_registers": 128, - "audio_connector_num_attention_heads": 30, - "audio_connector_attention_head_dim": 128, - "audio_connector_num_layers": 2, - "audio_connector_num_learnable_registers": 128, - "rope_base_seq_len": 4096, - "rope_theta": 10000.0, - "rope_double_precision": True, - "causal_temporal_positioning": False, - }, - } - rename_dict = LTX_2_0_TEXT_ENCODER_RENAME_DICT - special_keys_remap = LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP - return config, rename_dict, special_keys_remap - - -def get_text_encoder_keys_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str = "model.diffusion_model."): - model_state_dict = {} - - model_state_dict["text_proj_in.weight"] = combined_ckpt["text_embedding_projection.aggregate_embed.weight"] - - text_encoder_submodules = ["video_embeddings_connector", "audio_embeddings_connector"] - for param_name, param in combined_ckpt.items(): - if param_name.startswith(prefix): - new_param_name = param_name.replace(prefix, "") - for submodule_name in text_encoder_submodules: - if new_param_name.startswith(submodule_name): - model_state_dict[new_param_name] = param - break - - return model_state_dict - - -def convert_ltx2_text_encoder(original_state_dict: Dict[str, Any], version: str, text_model_id: str) -> Dict[str, Any]: - config, rename_dict, special_keys_remap = get_ltx2_text_encoder_config(version) - diffusers_config = config["diffusers_config"] - diffusers_config["text_model_id"] = text_model_id - diffusers_config["config_only"] = True - - with init_empty_weights(): - text_encoder = LTX2AudioVisualTextEncoder.from_config(diffusers_config) - - # Handle official code --> diffusers key remapping via the remap dict - for key in list(original_state_dict.keys()): - new_key = key[:] - for replace_key, rename_key in rename_dict.items(): - new_key = new_key.replace(replace_key, rename_key) - update_state_dict_inplace(original_state_dict, key, new_key) - - # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in - # special_keys_remap - for key in list(original_state_dict.keys()): - for special_key, 
handler_fn_inplace in special_keys_remap.items(): - if special_key not in key: - continue - handler_fn_inplace(key, original_state_dict) - - base_text_model = AutoModel.from_pretrained(text_model_id) - base_text_model_state_dict= base_text_model.state_dict() - base_text_model_state_dict = {"base_text_encoder." + k: v for k, v in base_text_model_state_dict.items()} - combined_state_dict = {**original_state_dict, **base_text_model_state_dict} - - text_encoder.load_state_dict(combined_state_dict, strict=True, assign=True) - return text_encoder - def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: @@ -588,6 +614,13 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefi for param_name, param in combined_ckpt.items(): if param_name.startswith(prefix): model_state_dict[param_name.replace(prefix, "")] = param + + if prefix == "model.diffusion_model.": + # Some checkpoints store the text connector projection outside the diffusion model prefix. + connector_key = "text_embedding_projection.aggregate_embed.weight" + if connector_key in combined_ckpt and connector_key not in model_state_dict: + model_state_dict[connector_key] = combined_ckpt[connector_key] + return model_state_dict @@ -649,6 +682,7 @@ def get_args(): parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") + parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model") parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model") parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder") parser.add_argument( @@ -721,6 +755,15 @@ def main(args): transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version) if not args.full_pipeline: transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer")) + + if args.connectors or args.full_pipeline: + if args.dit_filename is not None: + original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename) + elif combined_ckpt is not None: + original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix) + connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version) + if not args.full_pipeline: + connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors")) if args.vocoder or args.full_pipeline: if args.vocoder_filename is not None: @@ -732,8 +775,8 @@ def main(args): vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder")) if args.text_encoder or args.full_pipeline: - text_encoder_ckpt = get_text_encoder_keys_from_combined_ckpt(combined_ckpt) - text_encoder = convert_ltx2_text_encoder(text_encoder_ckpt, args.version, args.text_encoder_model_id) + # text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id) if not args.full_pipeline: text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder")) @@ -758,6 +801,7 @@ def main(args): audio_vae=audio_vae, text_encoder=text_encoder, tokenizer=tokenizer, + connectors=connectors, 
transformer=transformer, vocoder=vocoder, ) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 1f685fdc3a81..d0e5da2390f9 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -14,11 +14,9 @@ # limitations under the License. import inspect -import math from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn @@ -780,28 +778,12 @@ def forward( num_rope_elems = num_pos_dims * 2 # 4. Create a 1D grid of frequencies for RoPE - start = 1.0 - end = self.theta - if self.double_precision: - pow_indices = np.power( - self.theta, - np.linspace( - np.log(start) / np.log(self.theta), - np.log(end) / np.log(self.theta), - self.dim // num_rope_elems, - dtype=np.float64, - ), - ) - freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device) - else: - freqs = self.theta ** torch.linspace( - start=math.log(start, self.theta), - end=math.log(end, self.theta), - steps=self.dim // num_rope_elems, - device=device, - dtype=torch.float32, - ) - freqs = freqs * math.pi / 2.0 + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) # 5. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape # (self.dim // num_elems,) diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index a97c836e0c7d..95d5f8d4a445 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -24,7 +24,7 @@ else: _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] - _import_structure["text_encoder"] = ["LTX2AudioVisualTextEncoder"] + _import_structure["connectors"] = ["LTX2TextConnectors"] _import_structure["vocoder"] = ["LTX2Vocoder"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -37,7 +37,7 @@ else: from .pipeline_ltx2 import LTX2Pipeline from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline - from .text_encoder import LTX2AudioVisualTextEncoder + from .connectors import LTX2TextConnectors from .vocoder import LTX2Vocoder else: diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py new file mode 100644 index 000000000000..ce4dc4494f29 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -0,0 +1,281 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.attention import FeedForward +from ...models.modeling_utils import ModelMixin +from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor + + +class LTX2RotaryPosEmbed1d(nn.Module): + """ + 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. 
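+
+ The `forward` method returns a `(cos, sin)` pair of tensors, each of shape `(batch_size, seq_len, dim)`,
+ computed from token positions expressed as fractions of `base_seq_len`.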
+ """ + + def __init__( + self, + dim: int, + base_seq_len: int = 4096, + theta: float = 10000.0, + double_precision: bool = True, + ): + super().__init__() + self.dim = dim + self.base_seq_len = base_seq_len + self.theta = theta + self.double_precision = double_precision + + def forward( + self, + batch_size: int, + pos: int, + device: Union[str, torch.device], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Get 1D position ids + grid_1d = torch.arange(pos, dtype=torch.float32, device=device) + # Get fractional indices relative to self.base_seq_len + grid_1d = grid_1d / self.base_seq_len + grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] + + # 2. Calculate 1D RoPE frequencies + num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2 + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape + # (self.dim // 2,). + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2] + + # 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + return cos_freqs, sin_freqs + + +class LTX2TransformerBlock1d(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + activation_fn: str = "gelu-approximate", + eps: float = 1e-6, + ): + super().__init__() + + self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + processor=LTX2AudioVideoAttnProcessor(), + ) + + self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb) + hidden_states = hidden_states + attn_hidden_states + + norm_hidden_states = self.norm2(hidden_states) + ff_hidden_states = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_hidden_states + + return hidden_states + + +class LTX2ConnectorTransformer1d(nn.Module): + """ + A 1D sequence transformer for modalities such as text. + + In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams. 
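+
+ When `num_learnable_registers` is set, padded positions in the input sequence are replaced with learned
+ register embeddings and the additive attention mask is reset to zeros, so that every position takes part
+ in attention.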
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 128, + num_layers: int = 2, + num_learnable_registers: int | None = 128, + rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + eps: float = 1e-6, + causal_temporal_positioning: bool = False, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + + self.num_learnable_registers = num_learnable_registers + self.learnable_registers = None + if num_learnable_registers is not None: + init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0 + self.learnable_registers = torch.nn.Parameter(init_registers) + + self.rope = LTX2RotaryPosEmbed1d( + self.inner_dim, base_seq_len=rope_base_seq_len, theta=rope_theta, double_precision=rope_double_precision + ) + + self.transformer_blocks = torch.nn.ModuleList( + [ + LTX2TransformerBlock1d( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attn_mask_binarize_threshold: float = -9000.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # hidden_states shape: [batch_size, seq_len, hidden_dim] + # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len] + batch_size, seq_len, _ = hidden_states.shape + + # 1. Replace padding with learned registers, if using + if self.learnable_registers is not None: + if seq_len % self.num_learnable_registers != 0: + raise ValueError( + f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number" + f" of learnable registers {self.num_learnable_registers}" + ) + + num_register_repeats = seq_len // self.num_learnable_registers + registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim] + + binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int() + if binary_attn_mask.ndim == 4: + binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] + + hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)] + valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded] + pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens] + padded_hidden_states = [ + F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths) + ] + padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D] + + flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1] + hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers + + # Overwrite attention_mask with an all-zeros mask if using registers. + attention_mask = torch.zeros_like(attention_mask) + + # 2. Calculate 1D RoPE positional embeddings + rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device) + + # 3. 
Run 1D transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb) + else: + hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb) + + hidden_states = self.norm_out(hidden_states) + + return hidden_states, attention_mask + + +class LTX2TextConnectors(ModelMixin, ConfigMixin): + """ + Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and + audio streams. + """ + + @register_to_config + def __init__( + self, + caption_channels: int, + text_proj_in_factor: int, + video_connector_num_attention_heads: int, + video_connector_attention_head_dim: int, + video_connector_num_layers: int, + video_connector_num_learnable_registers: int | None, + audio_connector_num_attention_heads: int, + audio_connector_attention_head_dim: int, + audio_connector_num_layers: int, + audio_connector_num_learnable_registers: int | None, + connector_rope_base_seq_len: int, + rope_theta: float, + rope_double_precision: bool, + causal_temporal_positioning: bool, + ): + super().__init__() + self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False) + self.video_connector = LTX2ConnectorTransformer1d( + num_attention_heads=video_connector_num_attention_heads, + attention_head_dim=video_connector_attention_head_dim, + num_layers=video_connector_num_layers, + num_learnable_registers=video_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + ) + self.audio_connector = LTX2ConnectorTransformer1d( + num_attention_heads=audio_connector_num_attention_heads, + attention_head_dim=audio_connector_attention_head_dim, + num_layers=audio_connector_num_layers, + num_learnable_registers=audio_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + ) + + def forward( + self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False + ): + # Convert to additive attention mask, if necessary + if not additive_mask: + text_dtype = text_encoder_hidden_states.dtype + attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max + + text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states) + + video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask) + + attn_mask = (new_attn_mask < 1e-6).to(torch.int64) + attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) + video_text_embedding = video_text_embedding * attn_mask + new_attn_mask = attn_mask.squeeze(-1) + + audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask) + + return video_text_embedding, audio_text_embedding, new_attn_mask diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 2617e5cacb64..08fad91c4188 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -29,8 +29,8 
@@ from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline +from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput -from .text_encoder import LTX2AudioVisualTextEncoder from .vocoder import LTX2Vocoder @@ -192,9 +192,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix tokenizer (`T5TokenizerFast`): Second Tokenizer of class [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + connectors ([`LTX2TextConnectors`]): + Text connector stack used to adapt text encoder hidden states for the video and audio branches. """ - model_cpu_offload_seq = "text_encoder->transformer->vae->audio_vae->vocoder" + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -203,8 +205,9 @@ def __init__( scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKLLTX2Video, audio_vae: AutoencoderKLLTX2Audio, - text_encoder: LTX2AudioVisualTextEncoder, + text_encoder: Gemma3ForConditionalGeneration, tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, vocoder: LTX2Vocoder, ): @@ -215,6 +218,7 @@ def __init__( audio_vae=audio_vae, text_encoder=text_encoder, tokenizer=tokenizer, + connectors=connectors, transformer=transformer, vocoder=vocoder, scheduler=scheduler, @@ -252,6 +256,73 @@ def __init__( self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 ) + @staticmethod + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: Union[str, torch.device], + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. 
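+
+ For example, with the LTX 2.0 text encoder configuration (hidden dim 3840 and 49 stacked hidden-state
+ layers, matching the connectors' `text_proj_in_factor`), an input of shape `(batch_size, seq_len, 3840, 49)`
+ is normalized over the valid tokens of each batch instance and flattened to `(batch_size, seq_len, 188160)`.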
+ """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + def _get_gemma_prompt_embeds( self, prompt: Union[str, List[str]], @@ -274,7 +345,7 @@ def _get_gemma_prompt_embeds( max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. 
""" device = device or self._execution_device - dtype = dtype or self.text_encoder.base_text_encoder.dtype + dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -296,29 +367,34 @@ def _get_gemma_prompt_embeds( ) text_input_ids = text_inputs.input_ids prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) - prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self.text_encoder( - text_input_ids.to(device), - attention_mask=prompt_attention_mask.to(device), + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, padding_side=self.tokenizer.padding_side, scale_factor=scale_factor, ) prompt_embeds = prompt_embeds.to(dtype=dtype) - audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype) # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - _, audio_seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1) - prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) - return prompt_embeds, audio_prompt_embeds, prompt_attention_mask + return prompt_embeds, prompt_attention_mask def encode_prompt( self, @@ -327,9 +403,7 @@ def encode_prompt( do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, - audio_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_audio_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, max_sequence_length: int = 1024, @@ -372,7 +446,7 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, @@ -397,7 +471,7 @@ def encode_prompt( " the batch size of `prompt`." 
) - negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( prompt=negative_prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, @@ -406,7 +480,7 @@ def encode_prompt( dtype=dtype, ) - return prompt_embeds, audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask def check_inputs( self, @@ -668,10 +742,8 @@ def __call__( latents: Optional[torch.Tensor] = None, audio_latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, - audio_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_audio_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, decode_timestep: Union[float, List[float]] = 0.0, decode_noise_scale: Optional[Union[float, List[float]]] = None, @@ -732,17 +804,11 @@ def __call__( prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - audio_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings for audio processing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings for audio processing. For PixArt-Sigma this negative prompt should be "". If not - provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for negative text embeddings. decode_timestep (`float`, defaults to `0.0`): @@ -812,10 +878,8 @@ def __call__( # 3. 
Prepare text embeddings ( prompt_embeds, - audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, - negative_audio_prompt_embeds, negative_prompt_attention_mask, ) = self.encode_prompt( prompt=prompt, @@ -823,9 +887,7 @@ def __call__( do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, - audio_prompt_embeds=audio_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - negative_audio_prompt_embeds=negative_audio_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, max_sequence_length=max_sequence_length, @@ -833,9 +895,13 @@ def __call__( ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - audio_prompt_embeds = torch.cat([negative_audio_prompt_embeds, audio_prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + # 4. Prepare latent variables latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio @@ -939,11 +1005,11 @@ def __call__( noise_pred_video, noise_pred_audio = self.transformer( hidden_states=latent_model_input, audio_hidden_states=audio_latent_model_input, - encoder_hidden_states=prompt_embeds, - audio_encoder_hidden_states=audio_prompt_embeds, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - audio_encoder_attention_mask=prompt_attention_mask, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, num_frames=latent_num_frames, height=latent_height, width=latent_width, diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 359e665d4b28..caad9a1767d3 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -23,14 +23,14 @@ from ...image_processor import PipelineImageInput from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor +from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput from ..pipeline_utils import DiffusionPipeline -from .text_encoder import LTX2AudioVisualTextEncoder from .vocoder import LTX2Vocoder from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video from ...models.transformers import LTX2VideoTransformer3DModel -from transformers import GemmaTokenizer, GemmaTokenizerFast +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast from ...video_processor import VideoProcessor @@ -196,7 +196,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL TODO """ - model_cpu_offload_seq = "text_encoder->transformer->vae->audio_vae->vocoder" + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" 
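+ # The offload order mirrors the order in which the components are used at inference time: text encoding,
+ # then the text connectors, then the transformer, the video/audio VAEs, and the vocoder.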
_optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -205,8 +205,9 @@ def __init__( scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKLLTX2Video, audio_vae: AutoencoderKLLTX2Audio, - text_encoder: LTX2AudioVisualTextEncoder, + text_encoder: Gemma3ForConditionalGeneration, tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, vocoder: LTX2Vocoder, ): @@ -217,6 +218,7 @@ def __init__( audio_vae=audio_vae, text_encoder=text_encoder, tokenizer=tokenizer, + connectors=connectors, transformer=transformer, vocoder=vocoder, scheduler=scheduler, @@ -254,6 +256,74 @@ def __init__( self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 ) + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: Union[str, torch.device], + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. 
+ """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds def _get_gemma_prompt_embeds( self, @@ -277,7 +347,7 @@ def _get_gemma_prompt_embeds( max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. 
""" device = device or self._execution_device - dtype = dtype or self.text_encoder.base_text_encoder.dtype + dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -299,29 +369,34 @@ def _get_gemma_prompt_embeds( ) text_input_ids = text_inputs.input_ids prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) - prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self.text_encoder( - text_input_ids.to(device), - attention_mask=prompt_attention_mask.to(device), + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, padding_side=self.tokenizer.padding_side, scale_factor=scale_factor, ) prompt_embeds = prompt_embeds.to(dtype=dtype) - audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype) # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - _, audio_seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1) - prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) - return prompt_embeds, audio_prompt_embeds, prompt_attention_mask + return prompt_embeds, prompt_attention_mask # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt def encode_prompt( @@ -331,9 +406,7 @@ def encode_prompt( do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, - audio_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_audio_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, max_sequence_length: int = 1024, @@ -376,7 +449,7 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, @@ -401,7 +474,7 @@ def encode_prompt( " the batch size of `prompt`." 
) - negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( prompt=negative_prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, @@ -410,7 +483,7 @@ def encode_prompt( dtype=dtype, ) - return prompt_embeds, audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs def check_inputs( @@ -727,10 +800,8 @@ def __call__( latents: Optional[torch.Tensor] = None, audio_latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, - audio_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_audio_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, decode_timestep: Union[float, List[float]] = 0.0, decode_noise_scale: Optional[Union[float, List[float]]] = None, @@ -793,17 +864,11 @@ def __call__( prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - audio_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings for audio processing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings for audio processing. For PixArt-Sigma this negative prompt should be "". If not - provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for negative text embeddings. decode_timestep (`float`, defaults to `0.0`): @@ -873,10 +938,8 @@ def __call__( # 3. 
Prepare text embeddings ( prompt_embeds, - audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, - negative_audio_prompt_embeds, negative_prompt_attention_mask, ) = self.encode_prompt( prompt=prompt, @@ -884,9 +947,7 @@ def __call__( do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, - audio_prompt_embeds=audio_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - negative_audio_prompt_embeds=negative_audio_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, max_sequence_length=max_sequence_length, @@ -894,9 +955,13 @@ def __call__( ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - audio_prompt_embeds = torch.cat([negative_audio_prompt_embeds, audio_prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + # 4. Prepare latent variables if latents is None: image = self.video_processor.preprocess(image, height=height, width=width) @@ -1008,12 +1073,12 @@ def __call__( noise_pred_video, noise_pred_audio = self.transformer( hidden_states=latent_model_input, audio_hidden_states=audio_latent_model_input, - encoder_hidden_states=prompt_embeds, - audio_encoder_hidden_states=audio_prompt_embeds, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, timestep=video_timestep, audio_timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - audio_encoder_attention_mask=prompt_attention_mask, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, num_frames=latent_num_frames, height=latent_height, width=latent_width, diff --git a/src/diffusers/pipelines/ltx2/text_encoder.py b/src/diffusers/pipelines/ltx2/text_encoder.py deleted file mode 100644 index f15fa62224d2..000000000000 --- a/src/diffusers/pipelines/ltx2/text_encoder.py +++ /dev/null @@ -1,625 +0,0 @@ -# Copyright 2025 The Lightricks team and The HuggingFace Team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
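The file removed below previously bundled the Gemma backbone and the video/audio connectors into a single `LTX2AudioVisualTextEncoder`. After this change the pipelines call `Gemma3ForConditionalGeneration` directly, pack its per-layer hidden states with `_pack_text_embeds` (masked mean plus min/max scaling over non-padded positions), and hand the result to the standalone `LTX2TextConnectors` module. A condensed sketch of that flow, using placeholder names (`pipe`, `ids`, `mask`) and illustrative shapes rather than the exact pipeline code:

    import torch

    # ids: [B, S] token ids, mask: [B, S] binary attention mask, pipe: an LTX2 pipeline instance
    out = pipe.text_encoder(input_ids=ids, attention_mask=mask, output_hidden_states=True)
    hidden = torch.stack(out.hidden_states, dim=-1)                 # [B, S, D, L + 1]
    packed = pipe._pack_text_embeds(
        hidden, mask.sum(dim=-1), device=ids.device,
        padding_side=pipe.tokenizer.padding_side,
    )                                                               # [B, S, D * (L + 1)]
    additive = (1 - mask.float()) * -1_000_000.0
    video_embeds, audio_embeds, new_mask = pipe.connectors(packed, additive, additive_mask=True)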
- -import inspect -import math -from typing import Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import AutoConfig, AutoModel, Gemma3ForConditionalGeneration - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.attention import AttentionMixin, AttentionModuleMixin, FeedForward -from ...models.attention_dispatch import dispatch_attention_fn -from ...models.embeddings import get_1d_rotary_pos_embed -from ...models.modeling_utils import ModelMixin -from ...utils import is_torch_version, logging -from ..pipeline_loading_utils import _fetch_class_library_tuple - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def apply_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - cos, sin = freqs - x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - return out - - -# Copied from diffusers.models.transformers.transformer_ltx2.LTX2AudioVideoAttnProcessor -class LTX2AudioVideoAttnProcessor: - r""" - Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model. - Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can - support audio-to-video (a2v) and video-to-audio (v2a) cross attention. - """ - - _attention_backend = None - _parallel_config = None - - def __init__(self): - if is_torch_version("<", "2.0"): - raise ValueError( - "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." 
- ) - - def __call__( - self, - attn: "LTX2Attention", - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - ) -> torch.Tensor: - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.norm_q(query) - key = attn.norm_k(key) - - if query_rotary_emb is not None: - query = apply_rotary_emb(query, query_rotary_emb) - key = apply_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) - - query = query.unflatten(2, (attn.heads, -1)) - key = key.unflatten(2, (attn.heads, -1)) - value = value.unflatten(2, (attn.heads, -1)) - - hidden_states = dispatch_attention_fn( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -# Copied from diffusers.models.transformers.transformer_ltx2.LTX2Attention -class LTX2Attention(torch.nn.Module, AttentionModuleMixin): - r""" - Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key - RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention. 
- """ - - _default_processor_cls = LTX2AudioVideoAttnProcessor - _available_processors = [LTX2AudioVideoAttnProcessor] - - def __init__( - self, - query_dim: int, - heads: int = 8, - kv_heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = True, - cross_attention_dim: Optional[int] = None, - out_bias: bool = True, - qk_norm: str = "rms_norm_across_heads", - norm_eps: float = 1e-6, - norm_elementwise_affine: bool = True, - processor=None, - ): - super().__init__() - if qk_norm != "rms_norm_across_heads": - raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.") - - self.head_dim = dim_head - self.inner_dim = dim_head * heads - self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads - self.query_dim = query_dim - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.use_bias = bias - self.dropout = dropout - self.out_dim = query_dim - self.heads = heads - - self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) - self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) - self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - self.to_out = torch.nn.ModuleList([]) - self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(torch.nn.Dropout(dropout)) - - if processor is None: - processor = self._default_processor_cls() - self.set_processor(processor) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, - ) -> torch.Tensor: - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] - if len(unused_kwargs) > 0: - logger.warning( - f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." - ) - kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} - hidden_states = self.processor( - self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs - ) - return hidden_states - - -class LTX2RotaryPosEmbed1d(nn.Module): - """ - 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. - """ - - def __init__( - self, - dim: int, - base_seq_len: int = 4096, - theta: float = 10000.0, - double_precision: bool = True, - ): - super().__init__() - self.dim = dim - self.base_seq_len = base_seq_len - self.theta = theta - self.double_precision = double_precision - - def forward( - self, - batch_size: int, - pos: int, - device: Union[str, torch.device], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # 1. Get 1D position ids - grid_1d = torch.arange(pos, dtype=torch.float32, device=device) - # Get fractional indices relative to self.base_seq_len - grid_1d = grid_1d / self.base_seq_len - grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] - - # 2. 
Calculate 1D RoPE frequencies - num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2 - start = 1.0 - end = self.theta - if self.double_precision: - pow_indices = np.power( - self.theta, - np.linspace( - np.log(start) / np.log(self.theta), - np.log(end) / np.log(self.theta), - self.dim // num_rope_elems, - dtype=np.float64, - ), - ) - freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device) - else: - freqs = self.theta ** torch.linspace( - start=math.log(start, self.theta), - end=math.log(end, self.theta), - steps=self.dim // num_rope_elems, - device=device, - dtype=torch.float32, - ) - freqs = freqs * math.pi / 2.0 - - # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape - # (self.dim // 2,). - freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2] - - # 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim - cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) - sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) - - if self.dim % num_rope_elems != 0: - cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) - sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) - cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) - sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) - - return cos_freqs, sin_freqs - - -class LTX2TransformerBlock1d(nn.Module): - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - activation_fn: str = "gelu-approximate", - eps: float = 1e-6, - ): - super().__init__() - - self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) - self.attn1 = LTX2Attention( - query_dim=dim, - heads=num_attention_heads, - kv_heads=num_attention_heads, - dim_head=attention_head_dim, - processor=LTX2AudioVideoAttnProcessor(), - ) - - self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) - self.ff = FeedForward(dim, activation_fn=activation_fn) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - norm_hidden_states = self.norm1(hidden_states) - attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb) - hidden_states = hidden_states + attn_hidden_states - - norm_hidden_states = self.norm2(hidden_states) - ff_hidden_states = self.ff(norm_hidden_states) - hidden_states = hidden_states + ff_hidden_states - - return hidden_states - - -class LTX2ConnectorTransformer1d(nn.Module): - """ - A 1D sequence transformer for modalities such as text. - - In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams. 
- """ - _supports_gradient_checkpointing = True - - def __init__( - self, - num_attention_heads: int = 30, - attention_head_dim: int = 128, - num_layers: int = 2, - num_learnable_registers: Optional[int] = 128, - rope_base_seq_len: int = 4096, - rope_theta: float = 10000.0, - rope_double_precision: bool = True, - eps: float = 1e-6, - causal_temporal_positioning: bool = False, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.inner_dim = num_attention_heads * attention_head_dim - self.causal_temporal_positioning = causal_temporal_positioning - - self.num_learnable_registers = num_learnable_registers - self.learnable_registers = None - if num_learnable_registers is not None: - init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0 - self.learnable_registers = torch.nn.Parameter(init_registers) - - self.rope = LTX2RotaryPosEmbed1d( - self.inner_dim, base_seq_len=rope_base_seq_len, theta=rope_theta, double_precision=rope_double_precision - ) - - self.transformer_blocks = torch.nn.ModuleList( - [ - LTX2TransformerBlock1d( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - ) - for _ in range(num_layers) - ] - ) - - self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False) - - self.gradient_checkpointing = False - - def forward( - self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # hidden_states shape: [batch_size, seq_len, hidden_dim] - # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len] - batch_size, seq_len, _ = hidden_states.shape - - # 1. Replace padding with learned registers, if using - if self.learnable_registers is not None: - if seq_len % self.num_learnable_registers != 0: - raise ValueError( - f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number" - f" of learnable registers {self.num_learnable_registers}" - ) - - num_register_repeats = seq_len // self.num_learnable_registers - registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim] - - binary_attn_mask = (attention_mask >= -9000.0).int() - if binary_attn_mask.ndim == 4: - binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] - - hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)] - valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded] - pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens] - padded_hidden_states = [ - F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths) - ] - padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D] - - flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1] - hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers - - # Overwrite attention_mask with an all-zeros mask if using registers. - attention_mask = torch.zeros_like(attention_mask) - - # 2. Calculate 1D RoPE positional embeddings - rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device) - - # 3. 
Run 1D transformer blocks - for block in self.transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb) - else: - hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb) - - hidden_states = self.norm_out(hidden_states) - - return hidden_states, attention_mask - - -class LTX2AudioVisualTextEncoder(ModelMixin, ConfigMixin): - ignore_for_config = ["text_model"] - - @register_to_config - def __init__( - self, - text_model: Optional[Gemma3ForConditionalGeneration] = None, - text_model_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized", - text_encoder_hidden_dim: Optional[int] = 3840, - text_proj_in_factor: Optional[int] = 49, # Num layers in text encoder + 1 - video_connector_num_attention_heads: int = 30, - video_connector_attention_head_dim: int = 128, - video_connector_num_layers: int = 2, - video_connector_num_learnable_registers: int = 128, - audio_connector_num_attention_heads: int = 30, - audio_connector_attention_head_dim: int = 128, - audio_connector_num_layers: int = 2, - audio_connector_num_learnable_registers: Optional[int] = 128, - rope_base_seq_len: int = 4096, - rope_theta: float = 10000.0, - rope_double_precision: bool = True, - causal_temporal_positioning: bool = False, - config_only: bool = True, - ): - super().__init__() - if text_model is None: - self.set_base_text_encoder(text_model_id, config_only=config_only) - else: - self.base_text_encoder = text_model - - if text_encoder_hidden_dim is None: - if hasattr(self.base_text_encoder, "config"): - if hasattr(self.base_text_encoder.config, "hidden_size"): - text_encoder_hidden_dim = getattr(self.base_text_encoder.config, "hidden_size", None) - elif hasattr(self.base_text_encoder.config, "text_config"): - text_encoder_hidden_dim = getattr(self.base_text_encoder.config.text_config, "hidden_size", None) - if text_encoder_hidden_dim is None: - raise ValueError( - "`text_encoder_hidden_dim` is `None` and it cannot be inferred, please provide a value for it." - ) - - if text_proj_in_factor is None: - num_layers = None - if hasattr(self.base_text_encoder, "config"): - if hasattr(self.base_text_encoder.config, "num_hidden_layers"): - num_layers = getattr(self.base_text_encoder.config, "num_hidden_layers", None) - elif hasattr(self.base_text_encoder.config, "text_config"): - num_layers = getattr(self.base_text_encoder.config.text_config, "num_hidden_layers", None) - if num_layers is None: - raise ValueError( - "`text_proj_in_factor` is `None` and it cannot be inferred, please provide a value for it." 
- ) - text_proj_in_factor = num_layers + 1 - - self.text_proj_in = nn.Linear( - text_encoder_hidden_dim * text_proj_in_factor, text_encoder_hidden_dim, bias=False - ) - - self.video_connector = LTX2ConnectorTransformer1d( - num_attention_heads=video_connector_num_attention_heads, - attention_head_dim=video_connector_attention_head_dim, - num_layers=video_connector_num_layers, - num_learnable_registers=video_connector_num_learnable_registers, - rope_base_seq_len=rope_base_seq_len, - rope_theta=rope_theta, - rope_double_precision=rope_double_precision, - causal_temporal_positioning=causal_temporal_positioning, - ) - self.audio_connector = LTX2ConnectorTransformer1d( - num_attention_heads=audio_connector_num_attention_heads, - attention_head_dim=audio_connector_attention_head_dim, - num_layers=audio_connector_num_layers, - num_learnable_registers=audio_connector_num_learnable_registers, - rope_base_seq_len=rope_base_seq_len, - rope_theta=rope_theta, - rope_double_precision=rope_double_precision, - causal_temporal_positioning=causal_temporal_positioning, - ) - - def set_base_text_encoder( - self, base_text_encoder_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized", config_only: bool = True - ): - if config_only: - base_text_encoder_config = AutoConfig.from_pretrained(base_text_encoder_id) - base_text_encoder = AutoModel.from_config(base_text_encoder_config) - else: - base_text_encoder = AutoModel.from_pretrained(base_text_encoder_id) - self.base_text_encoder = base_text_encoder - - @staticmethod - def pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device: Union[str, torch.device], - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, - ) -> torch.Tensor: - """ - Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and - per-layer in a masked fashion (only over non-padded positions). - - Args: - text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): - Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). - sequence_lengths (`torch.Tensor of shape `(batch_size,)`): - The number of valid (non-padded) tokens for each batch instance. - device: (`str` or `torch.device`, *optional*): - torch device to place the resulting embeddings on - padding_side: (`str`, *optional*, defaults to `"left"`): - Whether the text tokenizer performs padding on the `"left"` or `"right"`. - scale_factor (`int`, *optional*, defaults to `8`): - Scaling factor to multiply the normalized hidden states by. - eps (`float`, *optional*, defaults to `1e-6`): - A small positive value for numerical stability when performing normalization. - - Returns: - `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: - Normed and flattened text encoder hidden states. 
- """ - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - # For right padding, valid tokens are from 0 to sequence_length-1 - mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] - elif padding_side == "left": - # For left padding, valid tokens are from (T - sequence_length) to T-1 - start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] - mask = token_indices >= start_indices # [B, T] - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] - - # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - # Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return normalized_hidden_states - - def run_connectors( - self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Run LTX 2.0-specific text embedding post-processing logic on top of the base text encoder hidden_states. - - Args: - text_encoder_hidden_states (`torch.Tensor`): - Text encoder packed hidden_states of shape `(batch_size, seq_len, hidden_dim * (num_layers + 1))`. - attention_mask (`torch.Tensor`): - Attention mask of shape `(batch_size, seq_len)`. - - Returns: - `Tuple(torch.Tensor, torch.Tensor, torch.Tensor)]`: - Returns a 3-tuple of tensors where the first element is the video text embeddings of shape - `(batch_size, seq_len, hidden_dim)`, the second element is the audio text embeddings of shape - `(batch_size, seq_len, hidden_dim)`, and the third element is an attention mask of shape - `(batch_size, seq_len)`. 
- """ - # Convert to additive attention mask - text_dtype = text_encoder_hidden_states.dtype - connector_attn_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) - connector_attn_mask = connector_attn_mask.to(text_dtype) * torch.finfo(text_dtype).max - - text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states) - - video_text_embedding, new_attn_mask = self.video_connector( - text_encoder_hidden_states, connector_attn_mask - ) - - attn_mask = (new_attn_mask < 1e-6).to(torch.int64) - attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) - video_text_embedding = video_text_embedding * attn_mask - new_attn_mask = attn_mask.squeeze(-1) - - audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, connector_attn_mask) - - return video_text_embedding, audio_text_embedding, new_attn_mask - - def forward( - self, - text_input_ids, - attention_mask: Optional[torch.Tensor] = None, - padding_side: str = "left", - scale_factor: int = 8, - ): - text_encoder_outputs = self.base_text_encoder( - input_ids=text_input_ids, attention_mask=attention_mask, output_hidden_states=True - ) - - text_encoder_hidden_states = text_encoder_outputs.hidden_states - text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - sequence_lengths = attention_mask.sum(dim=-1) - - text_encoder_hidden_states = self.pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=text_encoder_hidden_states.device, - padding_side=padding_side, - scale_factor=scale_factor, - ) - - video_text_embedding, audio_text_embedding, new_attn_mask = self.run_connectors( - text_encoder_hidden_states, attention_mask - ) - - return video_text_embedding, audio_text_embedding, new_attn_mask diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index 079273e975e5..1b0a7dd28f26 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -99,6 +99,7 @@ def prepare_init_args_and_inputs_for_common(self): "num_layers": 2, "qk_norm": "rms_norm_across_heads", "caption_channels": 16, + "rope_double_precision": False, } inputs_dict = self.dummy_input return init_dict, inputs_dict From 0be4f31620122594a7502f6aca660611b54fa6e9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 5 Jan 2026 21:13:01 +0530 Subject: [PATCH 52/86] up (#19) --- scripts/convert_ltx2_to_diffusers.py | 3 +-- scripts/ltx2_test_full_pipeline.py | 12 +----------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 9f58d8f344ce..039761134fbd 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -786,8 +786,7 @@ def main(args): if args.full_pipeline: scheduler = FlowMatchEulerDiscreteScheduler( - use_dynamic_shifting=False, - shift=math.exp(2.05), # Equivalent to dynamic shift if always using max_image_seq_len + use_dynamic_shifting=True, base_shift=0.95, max_shift=2.05, base_image_seq_len=1024, diff --git a/scripts/ltx2_test_full_pipeline.py b/scripts/ltx2_test_full_pipeline.py index 37a649d5ea74..5f0f366e714e 100644 --- a/scripts/ltx2_test_full_pipeline.py +++ b/scripts/ltx2_test_full_pipeline.py @@ -7,7 +7,7 @@ import av # Needs to be installed separately (`pip install av`) import torch -from diffusers import LTX2Pipeline, FlowMatchEulerDiscreteScheduler 
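For context on the scheduler change in this commit: with `use_dynamic_shifting=True` the shift is no longer a single baked-in constant but is derived per call from the token sequence length, interpolating between `base_shift` and `max_shift`. Roughly, mirroring the flow-match `calculate_shift` helper used elsewhere in diffusers (the sequence-length bounds below are assumed values for illustration):

    import math

    def calculate_mu(seq_len, base_seq_len=1024, max_seq_len=4096, base_shift=0.95, max_shift=2.05):
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        return m * (seq_len - base_seq_len) + base_shift

    mu = calculate_mu(4096)                                    # equals max_shift at the longest sequences
    shift_sigma = lambda s: math.exp(mu) / (math.exp(mu) + (1 / s - 1))
    # at max_image_seq_len this reduces to the old static config, shift = math.exp(2.05)

That is why the hard-coded `shift=math.exp(2.05)` and the `--apply_scheduler_fix` workaround in the test script can be dropped.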
+from diffusers import LTX2Pipeline # Video export functions copied from original LTX 2.0 code @@ -151,7 +151,6 @@ def parse_args(): parser.add_argument("--frame_rate", type=float, default=25.0) parser.add_argument("--guidance_scale", type=float, default=3.0) parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--apply_scheduler_fix", action="store_true") parser.add_argument("--device", type=str, default="cuda:0") parser.add_argument("--dtype", type=str, default="bf16") @@ -181,15 +180,6 @@ def main(args): revision=args.revision, torch_dtype=args.dtype, ) - if args.apply_scheduler_fix: - max_shift = pipeline.scheduler.config.max_shift - time_shift_type = pipeline.scheduler.config.time_shift_type - fixed_scheduler = FlowMatchEulerDiscreteScheduler.from_config( - pipeline.scheduler.config, - use_dynamic_shifting=False, - shift=math.exp(max_shift) if time_shift_type == "exponential" else max_shift, - ) - pipeline.scheduler = fixed_scheduler pipeline.to(device=args.device) if args.cpu_offload: pipeline.enable_model_cpu_offload() From c5b52d6c9f7a5f17120eec36419da1c1b38979aa Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 5 Jan 2026 21:13:10 +0530 Subject: [PATCH 53/86] address initial feedback from lightricks team (#16) * cross_attn_timestep_scale_multiplier to 1000 * implement split rope type. * up * propagate rope_type to rope embed classes as well. * up --- scripts/convert_ltx2_to_diffusers.py | 4 +- .../models/transformers/transformer_ltx2.py | 115 ++++++++++++++++-- src/diffusers/pipelines/ltx2/connectors.py | 63 ++++++++-- 3 files changed, 160 insertions(+), 22 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 039761134fbd..3b8c9598b513 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -256,7 +256,8 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, "rope_double_precision": True, "causal_offset": 1, "timestep_scale_multiplier": 1000, - "cross_attn_timestep_scale_multiplier": 1, + "cross_attn_timestep_scale_multiplier": 1000, + "rope_type": "split" }, } rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT @@ -303,6 +304,7 @@ def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, "rope_theta": 10000.0, "rope_double_precision": True, "causal_temporal_positioning": False, + "rope_type": "split", }, } diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index d0e5da2390f9..2f31054319a8 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -37,13 +37,53 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def apply_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: +def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: cos, sin = freqs x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out +def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + + needs_reshape = False + if x.ndim != 4 and cos.ndim == 4: + # cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head) + b, h, t, _ = cos.shape + x = x.reshape(b, t, h, -1).swapaxes(1, 2) + 
needs_reshape = True + + # Split last dim (2*r) into (d=2, r) + last = x.shape[-1] + if last % 2 != 0: + raise ValueError(f"Expected x.shape[-1] to be even for split rotary, got {last}.") + r = last // 2 + + # (..., 2, r) + split_x = x.reshape(*x.shape[:-1], 2, r) + first_x = split_x[..., :1, :] # (..., 1, r) + second_x = split_x[..., 1:, :] # (..., 1, r) + + cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r) + sin_u = sin.unsqueeze(-2) + + out = split_x * cos_u + first_out = out[..., :1, :] + second_out = out[..., 1:, :] + + first_out.addcmul_(-sin_u, second_x) + second_out.addcmul_(sin_u, first_x) + + out = out.reshape(*out.shape[:-2], last) + + if needs_reshape: + out = out.swapaxes(1, 2).reshape(b, t, -1) + + return out + + +ROTARY_FN_MAP = {"interleaved": apply_interleaved_rotary_emb, "split": apply_split_rotary_emb} @dataclass class AudioVisualModelOutput(BaseOutput): @@ -147,8 +187,8 @@ def __call__( key = attn.norm_k(key) if query_rotary_emb is not None: - query = apply_rotary_emb(query, query_rotary_emb) - key = apply_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) + query = ROTARY_FN_MAP[attn.rope_type](query, query_rotary_emb) + key = ROTARY_FN_MAP[attn.rope_type](key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) query = query.unflatten(2, (attn.heads, -1)) key = key.unflatten(2, (attn.heads, -1)) @@ -194,6 +234,7 @@ def __init__( qk_norm: str = "rms_norm_across_heads", norm_eps: float = 1e-6, norm_elementwise_affine: bool = True, + rope_type: str = "interleaved", processor=None, ): super().__init__() @@ -209,6 +250,7 @@ def __init__( self.dropout = dropout self.out_dim = query_dim self.heads = heads + self.rope_type = rope_type self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) @@ -280,6 +322,7 @@ def __init__( attention_out_bias: bool = True, eps: float = 1e-6, elementwise_affine: bool = False, + rope_type: str = "interleaved", ): super().__init__() @@ -294,6 +337,7 @@ def __init__( cross_attention_dim=None, out_bias=attention_out_bias, qk_norm=qk_norm, + rope_type=rope_type, ) self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) @@ -306,6 +350,7 @@ def __init__( cross_attention_dim=None, out_bias=attention_out_bias, qk_norm=qk_norm, + rope_type=rope_type, ) # 2. Prompt Cross-Attention @@ -319,6 +364,7 @@ def __init__( bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, + rope_type=rope_type ) self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) @@ -331,6 +377,7 @@ def __init__( bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, + rope_type=rope_type ) # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention @@ -345,6 +392,7 @@ def __init__( bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, + rope_type=rope_type ) # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video @@ -358,6 +406,7 @@ def __init__( bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, + rope_type=rope_type, ) # 4. 
Feedforward layers @@ -561,6 +610,8 @@ def __init__( causal_offset: int = 1, modality: str = "video", double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, ) -> None: super().__init__() @@ -568,7 +619,12 @@ def __init__( self.patch_size = patch_size self.patch_size_t = patch_size_t + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.") + self.rope_type = rope_type + self.base_num_frames = base_num_frames + self.num_attention_heads = num_attention_heads # Video-specific self.base_height = base_height @@ -791,14 +847,41 @@ def forward( freqs = freqs.transpose(-1, -2).flatten(2) # [B, num_patches, self.dim // 2] # 6. Get real, interleaved (cos, sin) frequencies, padded to self.dim - cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) - sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + # TODO: consider implementing this as a utility and reuse in `connectors.py`. + # src/diffusers/pipelines/ltx2/connectors.py + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) - if self.dim % num_rope_elems != 0: - cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) - sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems]) - cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) - sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) return cos_freqs, sin_freqs @@ -885,7 +968,8 @@ def __init__( rope_double_precision: bool = True, causal_offset: int = 1, timestep_scale_multiplier: int = 1000, - cross_attn_timestep_scale_multiplier: int = 1, + cross_attn_timestep_scale_multiplier: int = 1000, + rope_type: str = "interleaved", ) -> None: super().__init__() @@ -952,6 +1036,8 @@ def __init__( causal_offset=causal_offset, modality="video", double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, ) self.audio_rope = LTX2AudioVideoRotaryPosEmbed( dim=audio_inner_dim, @@ -965,6 +1051,8 @@ def __init__( causal_offset=causal_offset, modality="audio", double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, ) # Audio-to-Video, Video-to-Audio Cross-Attention @@ -980,6 +1068,8 @@ def __init__( causal_offset=causal_offset, 
modality="video", double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, ) self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed( dim=audio_cross_attention_dim, @@ -992,6 +1082,8 @@ def __init__( causal_offset=causal_offset, modality="audio", double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads ) # 5. Transformer Blocks @@ -1012,6 +1104,7 @@ def __init__( attention_out_bias=attention_out_bias, eps=norm_eps, elementwise_affine=norm_elementwise_affine, + rope_type=rope_type, ) for _ in range(num_layers) ] diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py index ce4dc4494f29..c146c9833e71 100644 --- a/src/diffusers/pipelines/ltx2/connectors.py +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -9,7 +9,6 @@ from ...models.modeling_utils import ModelMixin from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor - class LTX2RotaryPosEmbed1d(nn.Module): """ 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. @@ -21,12 +20,19 @@ def __init__( base_seq_len: int = 4096, theta: float = 10000.0, double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32 ): super().__init__() + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.") + self.dim = dim self.base_seq_len = base_seq_len self.theta = theta self.double_precision = double_precision + self.rope_type = rope_type + self.num_attention_heads = num_attention_heads def forward( self, @@ -54,14 +60,39 @@ def forward( freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2] # 4. 
Get real, interleaved (cos, sin) frequencies, padded to self.dim - cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) - sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) - - if self.dim % num_rope_elems != 0: - cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) - sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) - cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) - sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) + + cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) return cos_freqs, sin_freqs @@ -74,6 +105,7 @@ def __init__( attention_head_dim: int, activation_fn: str = "gelu-approximate", eps: float = 1e-6, + rope_type: str = "interleaved", ): super().__init__() @@ -84,6 +116,7 @@ def __init__( kv_heads=num_attention_heads, dim_head=attention_head_dim, processor=LTX2AudioVideoAttnProcessor(), + rope_type=rope_type ) self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) @@ -126,6 +159,7 @@ def __init__( rope_double_precision: bool = True, eps: float = 1e-6, causal_temporal_positioning: bool = False, + rope_type: str = "interleaved" ): super().__init__() self.num_attention_heads = num_attention_heads @@ -139,7 +173,12 @@ def __init__( self.learnable_registers = torch.nn.Parameter(init_registers) self.rope = LTX2RotaryPosEmbed1d( - self.inner_dim, base_seq_len=rope_base_seq_len, theta=rope_theta, double_precision=rope_double_precision + self.inner_dim, + base_seq_len=rope_base_seq_len, + theta=rope_theta, + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads ) self.transformer_blocks = torch.nn.ModuleList( @@ -148,6 +187,7 @@ def __init__( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, + rope_type=rope_type ) for _ in range(num_layers) ] @@ -234,6 +274,7 @@ def __init__( rope_theta: float, rope_double_precision: bool, causal_temporal_positioning: bool, + rope_type: str = "interleaved", ): super().__init__() self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False) @@ -246,6 +287,7 @@ def __init__( rope_theta=rope_theta, rope_double_precision=rope_double_precision, causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type 
) self.audio_connector = LTX2ConnectorTransformer1d( num_attention_heads=audio_connector_num_attention_heads, @@ -256,6 +298,7 @@ def __init__( rope_theta=rope_theta, rope_double_precision=rope_double_precision, causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type ) def forward( From 2fa4f8471f933dfaec79f074ebe6f96a76752b47 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 6 Jan 2026 00:19:39 +0100 Subject: [PATCH 54/86] When using split RoPE, make sure that the output dtype is same as input dtype --- src/diffusers/models/transformers/transformer_ltx2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 2f31054319a8..9c41bf949e02 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -47,6 +47,7 @@ def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, tor def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: cos, sin = freqs + x_dtype = x.dtype needs_reshape = False if x.ndim != 4 and cos.ndim == 4: # cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head) @@ -61,7 +62,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten r = last // 2 # (..., 2, r) - split_x = x.reshape(*x.shape[:-1], 2, r) + split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float first_x = split_x[..., :1, :] # (..., 1, r) second_x = split_x[..., 1:, :] # (..., 1, r) @@ -80,6 +81,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten if needs_reshape: out = out.swapaxes(1, 2).reshape(b, t, -1) + out = out.to(dtype=x_dtype) return out From bff989110c7695fe928661a849f43fde715c9ee0 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 6 Jan 2026 01:22:05 +0100 Subject: [PATCH 55/86] Fix apply split RoPE shape error when reshaping x to 4D --- src/diffusers/models/transformers/transformer_ltx2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 9c41bf949e02..4e3cd84ec71a 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -50,8 +50,10 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten x_dtype = x.dtype needs_reshape = False if x.ndim != 4 and cos.ndim == 4: - # cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head) - b, h, t, _ = cos.shape + # cos is (#b, h, t, r) -> reshape x to (b, h, t, dim_per_head) + # The cos/sin batch dim may only be broadcastable, so take batch size from x + b = x.shape[0] + _, h, t, _ = cos.shape x = x.reshape(b, t, h, -1).swapaxes(1, 2) needs_reshape = True From cb50cacba53e3b412292e491ebee5c3d779bf167 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 6 Jan 2026 02:17:39 +0100 Subject: [PATCH 56/86] Add export_utils file for exporting LTX 2.0 videos with audio --- src/diffusers/pipelines/ltx2/export_utils.py | 134 +++++++++++++++++++ src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 5 + 3 files changed, 140 insertions(+) create mode 100644 src/diffusers/pipelines/ltx2/export_utils.py diff --git a/src/diffusers/pipelines/ltx2/export_utils.py b/src/diffusers/pipelines/ltx2/export_utils.py new file mode 100644 index 
000000000000..0bc7a59db228 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/export_utils.py @@ -0,0 +1,134 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fractions import Fraction +from typing import Optional + +import torch + +from ...utils import is_av_available + + +_CAN_USE_AV = is_av_available() +if _CAN_USE_AV: + import av +else: + raise ImportError( + "PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`" + ) + + +def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: + """ + Prepare the audio stream for writing. + """ + audio_stream = container.add_stream("aac", rate=audio_sample_rate) + audio_stream.codec_context.sample_rate = audio_sample_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) + return audio_stream + + +def _resample_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame +) -> None: + cc = audio_stream.codec_context + + # Use the encoder's format/layout/rate as the *target* + target_format = cc.format or "fltp" # AAC → usually fltp + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or frame_in.sample_rate + + audio_resampler = av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in audio_resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + # flush audio encoder + for packet in audio_stream.encode(): + container.mux(packet) + + +def _write_audio( + container: av.container.Container, + audio_stream: av.audio.AudioStream, + samples: torch.Tensor, + audio_sample_rate: int, +) -> None: + if samples.ndim == 1: + samples = samples[:, None] + + if samples.shape[1] != 2 and samples.shape[0] == 2: + samples = samples.T + + if samples.shape[1] != 2: + raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.") + + # Convert to int16 packed for ingestion; resampler converts to encoder fmt. 
+ if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + _resample_audio(container, audio_stream, frame_in) + + +def encode_video( + video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str +) -> None: + video_np = video.cpu().numpy() + + _, height, width, _ = video_np.shape + + container = av.open(output_path, mode="w") + stream = container.add_stream("libx264", rate=int(fps)) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + if audio is not None: + if audio_sample_rate is None: + raise ValueError("audio_sample_rate is required when audio is provided") + + audio_stream = _prepare_audio_stream(container, audio_sample_rate) + + for frame_array in video_np: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + + # Flush encoder + for packet in stream.encode(): + container.mux(packet) + + if audio is not None: + _write_audio(container, audio_stream, audio, audio_sample_rate) + + container.close() diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 6884d3be9292..b160e9925425 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -66,6 +66,7 @@ is_accelerate_version, is_aiter_available, is_aiter_version, + is_av_available, is_better_profanity_available, is_bitsandbytes_available, is_bitsandbytes_version, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 57b0a337922a..425c360a3110 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -230,6 +230,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _aiter_available, _aiter_version = _is_package_available("aiter") _kornia_available, _kornia_version = _is_package_available("kornia") _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True) +_av_available, _av_version = _is_package_available("av") def is_torch_available(): @@ -420,6 +421,10 @@ def is_kornia_available(): return _kornia_available +def is_av_available(): + return _av_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the From 93a417f24a36e4ea5d5e6de69eb9e722acfc0abd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 6 Jan 2026 08:05:30 +0530 Subject: [PATCH 57/86] Tests for T2V and I2V (#6) * add ltx2 pipeline tests. * up * up * up * up * remove content * style * Denormalize audio latents in I2V pipeline (analogous to T2V change) * Initial refactor to put video and audio text encoder connectors in transformer * Get LTX 2 transformer tests working after connector refactor * up * up * i2v tests. * up * Address review comments * Calculate RoPE double precisions freqs using torch instead of np * Further simplify LTX 2 RoPE freq calc * revert unneded changes. * up * up * update to split style rope. 
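# A possible way to call the encode_video helper added in export_utils.py above (a
# sketch; it assumes PyAV is installed via `pip install av`). The shapes follow the
# function body: video as (frames, height, width, 3) uint8 RGB, audio as (samples, 2)
# float in [-1, 1].
import torch

from diffusers.pipelines.ltx2.export_utils import encode_video

video = (torch.rand(25, 256, 256, 3) * 255).to(torch.uint8)  # 1 second of video at 25 fps
audio = torch.zeros(24000, 2)                                # 1 second of stereo silence
encode_video(video, fps=25, audio=audio, audio_sample_rate=24000, output_path="sample.mp4")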
* up --------- Co-authored-by: Daniel Gu --- scripts/convert_ltx2_to_diffusers.py | 28 +- scripts/ltx2_test_full_pipeline.py | 3 +- scripts/ltx2_test_full_pipeline_i2v.py | 11 +- scripts/test_ltx2_audio_conversion.py | 2 +- src/diffusers/__init__.py | 17 +- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/autoencoders/__init__.py | 2 +- .../autoencoders/autoencoder_kl_ltx2.py | 17 +- .../autoencoders/autoencoder_kl_ltx2_audio.py | 9 +- .../models/transformers/transformer_ltx2.py | 88 ++++--- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/ltx2/__init__.py | 4 +- src/diffusers/pipelines/ltx2/connectors.py | 29 ++- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 16 +- .../ltx2/pipeline_ltx2_image2video.py | 42 +-- src/diffusers/pipelines/ltx2/vocoder.py | 22 +- src/diffusers/utils/dummy_pt_objects.py | 45 ++++ .../dummy_torch_and_transformers_objects.py | 30 +++ .../test_models_autoencoder_ltx2_video.py | 2 - .../test_models_transformer_ltx2.py | 6 +- tests/pipelines/ltx2/__init__.py | 0 tests/pipelines/ltx2/test_ltx2.py | 239 +++++++++++++++++ tests/pipelines/ltx2/test_ltx2_image2video.py | 241 ++++++++++++++++++ 23 files changed, 725 insertions(+), 134 deletions(-) create mode 100644 tests/pipelines/ltx2/__init__.py create mode 100644 tests/pipelines/ltx2/test_ltx2.py create mode 100644 tests/pipelines/ltx2/test_ltx2_image2video.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 3b8c9598b513..fa461979785e 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -1,5 +1,4 @@ import argparse -import math import os from contextlib import nullcontext from typing import Any, Dict, Optional, Tuple @@ -8,9 +7,15 @@ import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from transformers import AutoModel, AutoTokenizer, Gemma3ForConditionalGeneration - -from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler, LTX2Pipeline, LTX2VideoTransformer3DModel +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2Pipeline, + LTX2VideoTransformer3DModel, +) from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder from diffusers.utils.import_utils import is_accelerate_available @@ -186,7 +191,7 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, "num_attention_heads": 2, "attention_head_dim": 8, "cross_attention_dim": 16, - "vae_scale_factors": (8, 32 ,32), + "vae_scale_factors": (8, 32, 32), "pos_embed_max_pos": 20, "base_height": 2048, "base_width": 2048, @@ -229,7 +234,7 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, "num_attention_heads": 32, "attention_head_dim": 128, "cross_attention_dim": 4096, - "vae_scale_factors": (8, 32 ,32), + "vae_scale_factors": (8, 32, 32), "pos_embed_max_pos": 20, "base_height": 2048, "base_width": 2048, @@ -257,7 +262,7 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, "causal_offset": 1, "timestep_scale_multiplier": 1000, "cross_attn_timestep_scale_multiplier": 1000, - "rope_type": "split" + "rope_type": "split", }, } rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT @@ -307,7 +312,7 @@ def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, "rope_type": "split", }, } - + rename_dict = 
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT special_keys_remap = {} @@ -541,7 +546,7 @@ def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "leaky_relu_negative_slope": 0.1, "output_sampling_rate": 24000, - } + }, } rename_dict = LTX_2_0_VOCODER_RENAME_DICT special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP @@ -574,7 +579,6 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D return vocoder - def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) @@ -757,7 +761,7 @@ def main(args): transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version) if not args.full_pipeline: transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer")) - + if args.connectors or args.full_pipeline: if args.dit_filename is not None: original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename) @@ -810,6 +814,6 @@ def main(args): pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") -if __name__ == '__main__': +if __name__ == "__main__": args = get_args() main(args) diff --git a/scripts/ltx2_test_full_pipeline.py b/scripts/ltx2_test_full_pipeline.py index 5f0f366e714e..16ea9f80404f 100644 --- a/scripts/ltx2_test_full_pipeline.py +++ b/scripts/ltx2_test_full_pipeline.py @@ -1,5 +1,4 @@ import argparse -import math import os from fractions import Fraction from typing import Optional @@ -211,6 +210,6 @@ def main(args): ) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() main(args) diff --git a/scripts/ltx2_test_full_pipeline_i2v.py b/scripts/ltx2_test_full_pipeline_i2v.py index 01b18e5eb814..8c39647eae88 100644 --- a/scripts/ltx2_test_full_pipeline_i2v.py +++ b/scripts/ltx2_test_full_pipeline_i2v.py @@ -1,12 +1,11 @@ - import argparse import os from fractions import Fraction from typing import Optional -from PIL import Image import av # Needs to be installed separately (`pip install av`) import torch +from PIL import Image from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline @@ -131,7 +130,7 @@ def parse_args(): parser.add_argument( "--negative_prompt", type=str, - default="shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." 
+ default="shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static.", ) parser.add_argument("--num_inference_steps", type=int, default=40) @@ -166,7 +165,9 @@ def parse_args(): def main(args): pipeline = LTX2ImageToVideoPipeline.from_pretrained( - args.model_id, revision=args.revision, torch_dtype=args.dtype, + args.model_id, + revision=args.revision, + torch_dtype=args.dtype, ) if args.cpu_offload: pipeline.enable_model_cpu_offload() @@ -201,6 +202,6 @@ def main(args): ) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() main(args) diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py index a6ba16ed9efa..3aa2a65d3f16 100644 --- a/scripts/test_ltx2_audio_conversion.py +++ b/scripts/test_ltx2_audio_conversion.py @@ -90,7 +90,7 @@ def main() -> None: latent_width, device=device, dtype=dtype, - generator=torch.Generator(device).manual_seed(42) + generator=torch.Generator(device).manual_seed(42), ) original_out = original_decoder(dummy) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2e99ea8063a4..9c9ade91548b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -193,9 +193,9 @@ "AutoencoderKLHunyuanImageRefiner", "AutoencoderKLHunyuanVideo", "AutoencoderKLHunyuanVideo15", - "AutoencoderKLLTXVideo", "AutoencoderKLLTX2Audio", "AutoencoderKLLTX2Video", + "AutoencoderKLLTXVideo", "AutoencoderKLMagvit", "AutoencoderKLMochi", "AutoencoderKLQwenImage", @@ -237,8 +237,8 @@ "Kandinsky3UNet", "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", - "LTXVideoTransformer3DModel", "LTX2VideoTransformer3DModel", + "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", "MochiTransformer3DModel", @@ -533,12 +533,13 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LTX2ImageToVideoPipeline", + "LTX2Pipeline", + "LTX2Pipeline", "LTXConditionPipeline", "LTXImageToVideoPipeline", "LTXLatentUpsamplePipeline", "LTXPipeline", - "LTX2Pipeline", - "LTX2ImageToVideoPipeline", "LucyEditPipeline", "Lumina2Pipeline", "Lumina2Text2ImgPipeline", @@ -931,9 +932,9 @@ AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, - AutoencoderKLLTXVideo, AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, + AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, @@ -975,8 +976,8 @@ Kandinsky3UNet, Kandinsky5Transformer3DModel, LatteTransformer3DModel, - LTXVideoTransformer3DModel, LTX2VideoTransformer3DModel, + LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, @@ -1241,12 +1242,12 @@ LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LTX2ImageToVideoPipeline, + LTX2Pipeline, LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline, - LTX2Pipeline, - LTX2ImageToVideoPipeline, LucyEditPipeline, Lumina2Pipeline, Lumina2Text2ImgPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d3bcb3bcee7a..4d372e1112a0 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -154,9 +154,9 @@ AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, - AutoencoderKLLTXVideo, AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, + AutoencoderKLLTXVideo, 
AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, @@ -213,8 +213,8 @@ HunyuanVideoTransformer3DModel, Kandinsky5Transformer3DModel, LatteTransformer3DModel, - LTXVideoTransformer3DModel, LTX2VideoTransformer3DModel, + LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 38d52f0eb5e7..8e7a9c81d2ad 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -10,8 +10,8 @@ from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo -from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video +from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index df59e2d74868..2d55f166c6fd 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -25,7 +25,6 @@ from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from ..normalization import RMSNorm from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution @@ -33,8 +32,8 @@ class PerChannelRMSNorm(nn.Module): """ Per-pixel (per-location) RMS normalization layer. 
- For each element along the chosen dimension, this layer normalizes the tensor - by the root-mean-square of its values across that dimension: + For each element along the chosen dimension, this layer normalizes the tensor by the root-mean-square of its values + across that dimension: y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) """ @@ -174,9 +173,7 @@ def __init__( if in_channels != out_channels: self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d - self.conv_shortcut = nn.Conv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1 - ) + self.conv_shortcut = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1) self.per_channel_scale1 = None self.per_channel_scale2 = None @@ -953,7 +950,10 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, causal: Optional[bool] = None, + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, ) -> torch.Tensor: causal = causal or self.is_causal @@ -1279,7 +1279,8 @@ def decode( if self.use_slicing and z.shape[0] > 1: if temb is not None: decoded_slices = [ - self._decode(z_slice, t_slice, causal=causal).sample for z_slice, t_slice in (z.split(1), temb.split(1)) + self._decode(z_slice, t_slice, causal=causal).sample + for z_slice, t_slice in (z.split(1), temb.split(1)) ] else: decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)] diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 8cdcfa1a74c5..091d55645a5d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -249,6 +249,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + class LTX2AudioAudioPatchifier: """ Patchifier for spectrogram/audio latents. @@ -405,9 +406,7 @@ def __init__( final_block_channels = block_in if self.norm_type == "group": - self.norm_out = nn.GroupNorm( - num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True - ) + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True) elif self.norm_type == "pixel": self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) else: @@ -538,8 +537,8 @@ def __init__( # Per-channel statistics for normalizing and denormalizing the latent representation. 
This statics is computed over # the entire dataset and stored in model's checkpoint under AudioVAE state_dict - latents_std = torch.zeros((base_channels, )) - latents_mean = torch.ones((base_channels, )) + latents_std = torch.zeros((base_channels,)) + latents_mean = torch.ones((base_channels,)) self.register_buffer("latents_mean", latents_mean, persistent=True) self.register_buffer("latents_std", latents_std, persistent=True) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 4e3cd84ec71a..2182a59cd093 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -22,16 +22,21 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + is_torch_version, + logging, + scale_lora_layers, + unscale_lora_layers, +) from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin -from ..embeddings import PixArtAlphaTextProjection, PixArtAlphaCombinedTimestepSizeEmbeddings -from ..modeling_outputs import Transformer2DModelOutput +from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle, RMSNorm +from ..normalization import RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -44,6 +49,7 @@ def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, tor out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out + def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: cos, sin = freqs @@ -65,7 +71,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten # (..., 2, r) split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float - first_x = split_x[..., :1, :] # (..., 1, r) + first_x = split_x[..., :1, :] # (..., 1, r) second_x = split_x[..., 1:, :] # (..., 1, r) cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r) @@ -89,6 +95,7 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten ROTARY_FN_MAP = {"interleaved": apply_interleaved_rotary_emb, "split": apply_split_rotary_emb} + @dataclass class AudioVisualModelOutput(BaseOutput): r""" @@ -192,7 +199,9 @@ def __call__( if query_rotary_emb is not None: query = ROTARY_FN_MAP[attn.rope_type](query, query_rotary_emb) - key = ROTARY_FN_MAP[attn.rope_type](key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) + key = ROTARY_FN_MAP[attn.rope_type]( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) query = query.unflatten(2, (attn.heads, -1)) key = key.unflatten(2, (attn.heads, -1)) @@ -368,7 +377,7 @@ def __init__( bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, - rope_type=rope_type + rope_type=rope_type, ) self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) @@ -381,7 +390,7 @@ def __init__( 
bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, - rope_type=rope_type + rope_type=rope_type, ) # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention @@ -396,7 +405,7 @@ def __init__( bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, - rope_type=rope_type + rope_type=rope_type, ) # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video @@ -481,7 +490,9 @@ def forward( audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( batch_size, temb_audio.size(1), num_audio_ada_params, -1 ) - audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = audio_ada_values.unbind(dim=2) + audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( + audio_ada_values.unbind(dim=2) + ) norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa attn_audio_hidden_states = self.audio_attn1( @@ -550,8 +561,12 @@ def forward( if use_a2v_cross_attn: # Audio-to-Video Cross Attention: Q: Video; K,V: Audio - mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(2) - mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_a2v_ca_scale.squeeze(2)) + audio_a2v_ca_shift.squeeze(2) + mod_norm_hidden_states = norm_hidden_states * ( + 1 + video_a2v_ca_scale.squeeze(2) + ) + video_a2v_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_a2v_ca_scale.squeeze(2) + ) + audio_a2v_ca_shift.squeeze(2) a2v_attn_hidden_states = self.audio_to_video_attn( mod_norm_hidden_states, @@ -565,8 +580,12 @@ def forward( if use_v2a_cross_attn: # Video-to-Audio Cross Attention: Q: Audio; K,V: Video - mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(2) - mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_v2a_ca_scale.squeeze(2)) + audio_v2a_ca_shift.squeeze(2) + mod_norm_hidden_states = norm_hidden_states * ( + 1 + video_v2a_ca_scale.squeeze(2) + ) + video_v2a_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_v2a_ca_scale.squeeze(2) + ) + audio_v2a_ca_shift.squeeze(2) v2a_attn_hidden_states = self.video_to_audio_attn( mod_norm_audio_hidden_states, @@ -596,9 +615,10 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): Args: causal_offset (`int`, *optional*, defaults to `1`): - Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where - the VAE treats the very first frame differently), but could also be 0 (for non-causal modeling). + Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE + treats the very first frame differently), but could also be 0 (for non-causal modeling). """ + def __init__( self, dim: int, @@ -658,9 +678,9 @@ def prepare_video_coords( fps: float = 25.0, ) -> torch.Tensor: """ - Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original - pixel space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, - num_patches, 2) where + Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel + space video grid (num_frames, height, width). 
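# A compact sketch of the AdaLN-single modulation pattern used throughout the block
# above (illustrative shapes; the real model keeps separate scale/shift tables for the
# video and audio streams): a learned table is offset by the per-token timestep
# embedding, unbound into shift/scale/gate triples, and applied as x * (1 + scale) + shift.
import torch

batch, tokens, dim, num_params = 2, 16, 32, 6
scale_shift_table = torch.randn(num_params, dim)     # learned parameter in the real block
temb = torch.randn(batch, tokens, num_params * dim)  # per-token timestep embedding

ada = scale_shift_table[None, None] + temb.reshape(batch, tokens, num_params, dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada.unbind(dim=2)

hidden = torch.randn(batch, tokens, dim)
norm_hidden = hidden * torch.rsqrt(hidden.pow(2).mean(-1, keepdim=True) + 1e-6)  # RMS norm
modulated = norm_hidden * (1 + scale_msa) + shift_msa  # pre-attention modulation
hidden = hidden + gate_msa * modulated                 # `modulated` stands in for the attention output
print(hidden.shape)  # torch.Size([2, 16, 32])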
This will ultimately have shape (batch_size, 3, num_patches, 2) + where - axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames) - axis 3 (size 2) stores `[start, end)` indices within each dimension @@ -727,8 +747,8 @@ def prepare_audio_coords( shift: int = 0, ) -> torch.Tensor: """ - Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent - frame. This will ultimately have shape (batch_size, 3, num_patches, 2) where + Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent frame. + This will ultimately have shape (batch_size, 3, num_patches, 2) where - axis 1 (size 1) represents the temporal dimension - axis 3 (size 2) stores `[start, end)` indices within each dimension @@ -763,7 +783,7 @@ def prepare_audio_coords( # Handle first frame causal offset, ensuring non-negative timestamps grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0) # Convert mel bins back into seconds - grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate + grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate # 3. Calculate start timstamps in seconds with respect to the original spectrogram grid grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor @@ -862,7 +882,7 @@ def forward( sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems]) cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) - + elif self.rope_type == "split": expected_freqs = self.dim // 2 current_freqs = freqs.shape[-1] @@ -1087,7 +1107,7 @@ def __init__( modality="audio", double_precision=rope_double_precision, rope_type=rope_type, - num_attention_heads=audio_num_attention_heads + num_attention_heads=audio_num_attention_heads, ) # 5. Transformer Blocks @@ -1154,7 +1174,7 @@ def forward( encoder_hidden_states (`torch.Tensor`): Input text embeddings of shape TODO. TODO for the rest. - + Returns: `AudioVisualModelOutput` or `tuple`: If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a @@ -1204,14 +1224,18 @@ def forward( audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) - audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :], device=audio_hidden_states.device) + audio_cross_attn_rotary_emb = self.cross_attn_audio_rope( + audio_coords[:, 0:1, :], device=audio_hidden_states.device + ) # 2. Patchify input projections hidden_states = self.proj_in(hidden_states) audio_hidden_states = self.audio_proj_in(audio_hidden_states) # 3. Prepare timestep embeddings and modulation parameters - timestep_cross_attn_gate_scale_factor = self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier + timestep_cross_attn_gate_scale_factor = ( + self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier + ) # 3.1. 
Prepare global modality (video and audio) timestep embedding and modulation parameters # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer @@ -1243,7 +1267,9 @@ def forward( batch_size=batch_size, hidden_dtype=hidden_states.dtype, ) - video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(batch_size, -1, video_cross_attn_scale_shift.shape[-1]) + video_cross_attn_scale_shift = video_cross_attn_scale_shift.view( + batch_size, -1, video_cross_attn_scale_shift.shape[-1] + ) video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( @@ -1256,7 +1282,9 @@ def forward( batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) - audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(batch_size, -1, audio_cross_attn_scale_shift.shape[-1]) + audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view( + batch_size, -1, audio_cross_attn_scale_shift.shape[-1] + ) audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) # 4. Prepare prompt embeddings diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index eaf444d5ec37..39c8ce662322 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -720,7 +720,7 @@ LEditsPPPipelineStableDiffusionXL, ) from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline - from .ltx2 import LTX2Pipeline, LTX2ImageToVideoPipeline + from .ltx2 import LTX2ImageToVideoPipeline, LTX2Pipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index 95d5f8d4a445..2760f8f7feeb 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -22,9 +22,9 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: + _import_structure["connectors"] = ["LTX2TextConnectors"] _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] - _import_structure["connectors"] = ["LTX2TextConnectors"] _import_structure["vocoder"] = ["LTX2Vocoder"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -35,9 +35,9 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: + from .connectors import LTX2TextConnectors from .pipeline_ltx2 import LTX2Pipeline from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline - from .connectors import LTX2TextConnectors from .vocoder import LTX2Vocoder else: diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py index c146c9833e71..2608c2783f7e 100644 --- a/src/diffusers/pipelines/ltx2/connectors.py +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -9,6 +9,7 @@ from ...models.modeling_utils import ModelMixin from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor + class LTX2RotaryPosEmbed1d(nn.Module): """ 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. 
@@ -21,12 +22,12 @@ def __init__( theta: float = 10000.0, double_precision: bool = True, rope_type: str = "interleaved", - num_attention_heads: int = 32 + num_attention_heads: int = 32, ): super().__init__() if rope_type not in ["interleaved", "split"]: raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.") - + self.dim = dim self.base_seq_len = base_seq_len self.theta = theta @@ -69,7 +70,7 @@ def forward( sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) - + elif self.rope_type == "split": expected_freqs = self.dim // 2 current_freqs = freqs.shape[-1] @@ -116,7 +117,7 @@ def __init__( kv_heads=num_attention_heads, dim_head=attention_head_dim, processor=LTX2AudioVideoAttnProcessor(), - rope_type=rope_type + rope_type=rope_type, ) self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) @@ -159,7 +160,7 @@ def __init__( rope_double_precision: bool = True, eps: float = 1e-6, causal_temporal_positioning: bool = False, - rope_type: str = "interleaved" + rope_type: str = "interleaved", ): super().__init__() self.num_attention_heads = num_attention_heads @@ -173,12 +174,12 @@ def __init__( self.learnable_registers = torch.nn.Parameter(init_registers) self.rope = LTX2RotaryPosEmbed1d( - self.inner_dim, - base_seq_len=rope_base_seq_len, - theta=rope_theta, + self.inner_dim, + base_seq_len=rope_base_seq_len, + theta=rope_theta, double_precision=rope_double_precision, rope_type=rope_type, - num_attention_heads=num_attention_heads + num_attention_heads=num_attention_heads, ) self.transformer_blocks = torch.nn.ModuleList( @@ -187,7 +188,7 @@ def __init__( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, - rope_type=rope_type + rope_type=rope_type, ) for _ in range(num_layers) ] @@ -253,8 +254,8 @@ def forward( class LTX2TextConnectors(ModelMixin, ConfigMixin): """ - Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and - audio streams. + Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio + streams. 
""" @register_to_config @@ -287,7 +288,7 @@ def __init__( rope_theta=rope_theta, rope_double_precision=rope_double_precision, causal_temporal_positioning=causal_temporal_positioning, - rope_type=rope_type + rope_type=rope_type, ) self.audio_connector = LTX2ConnectorTransformer1d( num_attention_heads=audio_connector_num_attention_heads, @@ -298,7 +299,7 @@ def __init__( rope_theta=rope_theta, rope_double_precision=rope_double_precision, causal_temporal_positioning=causal_temporal_positioning, - rope_type=rope_type + rope_type=rope_type, ) def forward( diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 08fad91c4188..7cbcca67d2c6 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -674,7 +674,9 @@ def prepare_audio_latents( latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: duration_s = num_frames / frame_rate - latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + latents_per_second = ( + float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + ) latent_length = int(duration_s * latents_per_second) if latents is not None: @@ -995,7 +997,9 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = latent_model_input.to(prompt_embeds.dtype) - audio_latent_model_input = torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -1026,10 +1030,14 @@ def __call__( if self.do_classifier_free_guidance: noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) - noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (noise_pred_video_text - noise_pred_video_uncond) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) - noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) if self.guidance_rescale > 0: # Based on 3.4. in https://huggingface.co/papers/2305.08891 diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index caad9a1767d3..0a707806ce1b 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -13,25 +13,26 @@ # limitations under the License. 
import copy -from typing import Any, Callable, Dict, List, Optional, Union import inspect +from typing import Any, Callable, Dict, List, Optional, Union + import numpy as np import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast -from ...schedulers import FlowMatchEulerDiscreteScheduler from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput -from ..pipeline_utils import DiffusionPipeline from .vocoder import LTX2Vocoder -from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin -from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video -from ...models.transformers import LTX2VideoTransformer3DModel -from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast -from ...video_processor import VideoProcessor if is_torch_xla_available(): @@ -86,6 +87,7 @@ def retrieve_latents( else: raise AttributeError("Could not access latents of provided encoder_output") + # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, @@ -665,7 +667,7 @@ def prepare_latents( shape = (batch_size, num_channels_latents, num_frames, height, width) mask_shape = (batch_size, 1, num_frames, height, width) - + if latents is not None: conditioning_mask = latents.new_zeros(mask_shape) conditioning_mask[:, :, 0] = 1.0 @@ -697,7 +699,7 @@ def prepare_latents( init_latents = torch.cat(init_latents, dim=0).to(dtype) init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) - + # First condition is image latents and those should be kept clean. 
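# A hedged sketch of how the conditioning mask built here interacts with the per-token
# timestep later in the denoising loop (illustrative shapes and flattening order): mask = 1
# marks the clean image latents in frame 0, and multiplying by (1 - mask) zeroes their
# timestep so they are treated as already denoised.
import torch

batch, frames, height, width = 1, 3, 2, 2
conditioning_mask = torch.zeros(batch, 1, frames, height, width)
conditioning_mask[:, :, 0] = 1.0  # first latent frame holds the image condition

t = torch.tensor([800.0])                              # current scheduler timestep
packed_mask = conditioning_mask.flatten(2).squeeze(1)  # (batch, frames * height * width)
video_timestep = t.expand(batch).unsqueeze(-1) * (1 - packed_mask)
print(video_timestep[0, : height * width])             # zeros for the image tokens
print(video_timestep[0, height * width :])             # 800.0 for the noisy video tokens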
conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) conditioning_mask[:, :, 0] = 1.0 @@ -731,7 +733,9 @@ def prepare_audio_latents( latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: duration_s = num_frames / frame_rate - latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + latents_per_second = ( + float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + ) latent_length = int(duration_s * latents_per_second) if latents is not None: @@ -982,7 +986,7 @@ def __call__( ) if self.do_classifier_free_guidance: conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) - + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio @@ -1063,12 +1067,14 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = latent_model_input.to(prompt_embeds.dtype) - audio_latent_model_input = torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) timestep = t.expand(latent_model_input.shape[0]) video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) - + with self.transformer.cache_context("cond_uncond"): noise_pred_video, noise_pred_audio = self.transformer( hidden_states=latent_model_input, @@ -1095,10 +1101,14 @@ def __call__( if self.do_classifier_free_guidance: noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) - noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (noise_pred_video_text - noise_pred_video_uncond) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) - noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) if self.guidance_rescale > 0: # Based on 3.4. 
in https://huggingface.co/papers/2305.08891 diff --git a/src/diffusers/pipelines/ltx2/vocoder.py b/src/diffusers/pipelines/ltx2/vocoder.py index c3b3c1f36796..217c68103e39 100644 --- a/src/diffusers/pipelines/ltx2/vocoder.py +++ b/src/diffusers/pipelines/ltx2/vocoder.py @@ -25,32 +25,18 @@ def __init__( self.convs1 = nn.ModuleList( [ - nn.Conv1d( - channels, - channels, - kernel_size, - stride=stride, - dilation=dilation, - padding=padding_mode - ) + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode) for dilation in dilations ] ) self.convs2 = nn.ModuleList( [ - nn.Conv1d( - channels, - channels, - kernel_size, - stride=stride, - dilation=1, - padding=padding_mode - ) + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode) for _ in range(len(dilations)) ] ) - + def forward(self, x: torch.Tensor) -> torch.Tensor: for conv1, conv2 in zip(self.convs1, self.convs2): xt = F.leaky_relu(x, negative_slope=self.negative_slope) @@ -127,7 +113,7 @@ def __init__( input_channels = output_channels self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) - + def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: r""" Forward pass of the vocoder. diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8628893200fe..54746ecb5815 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -502,6 +502,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLLTX2Audio(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoencoderKLLTX2Video(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLLTXVideo(metaclass=DummyObject): _backends = ["torch"] @@ -1132,6 +1162,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LTX2VideoTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LTXVideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index da64742518bb..50a88afbb218 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1802,6 +1802,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTX2ImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + 
@classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LTX2Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXConditionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py index 25984d621ac0..146241361a82 100644 --- a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py @@ -15,8 +15,6 @@ import unittest -import torch - from diffusers import AutoencoderKLLTX2Video from ...testing_utils import ( diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index 1b0a7dd28f26..8a6b50b55eea 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -52,9 +52,9 @@ def dummy_input(self): sequence_length = 16 hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) - audio_hidden_states = torch.randn( - (batch_size, audio_num_frames, audio_num_channels * num_mel_bins) - ).to(torch_device) + audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to( + torch_device + ) encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) diff --git a/tests/pipelines/ltx2/__init__.py b/tests/pipelines/ltx2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/ltx2/test_ltx2.py b/tests/pipelines/ltx2/test_ltx2.py new file mode 100644 index 000000000000..73d08e6b1a20 --- /dev/null +++ b/tests/pipelines/ltx2/test_ltx2.py @@ -0,0 +1,239 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2Pipeline, + LTX2VideoTransformer3DModel, +) +from diffusers.pipelines.ltx2 import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder + +from ...testing_utils import enable_full_determinism +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTX2Pipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "audio_latents", + "output_type", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_attention_slicing = False + test_xformers_attention = False + supports_dduf = False + + base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3" + + def get_dummy_components(self): + tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id) + + torch.manual_seed(0) + transformer = LTX2VideoTransformer3DModel( + in_channels=4, + out_channels=4, + patch_size=1, + patch_size_t=1, + num_attention_heads=2, + attention_head_dim=8, + cross_attention_dim=16, + audio_in_channels=4, + audio_out_channels=4, + audio_num_attention_heads=2, + audio_attention_head_dim=4, + audio_cross_attention_dim=8, + num_layers=2, + qk_norm="rms_norm_across_heads", + caption_channels=text_encoder.config.text_config.hidden_size, + rope_double_precision=False, + rope_type="split", + ) + + torch.manual_seed(0) + connectors = LTX2TextConnectors( + caption_channels=text_encoder.config.text_config.hidden_size, + text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1, + video_connector_num_attention_heads=4, + video_connector_attention_head_dim=8, + video_connector_num_layers=1, + video_connector_num_learnable_registers=None, + audio_connector_num_attention_heads=4, + audio_connector_attention_head_dim=8, + audio_connector_num_layers=1, + audio_connector_num_learnable_registers=None, + connector_rope_base_seq_len=32, + rope_theta=10000.0, + rope_double_precision=False, + causal_temporal_positioning=False, + rope_type="split", + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTX2Video( + in_channels=3, + out_channels=3, + latent_channels=4, + block_out_channels=(8,), + decoder_block_out_channels=(8,), + layers_per_block=(1,), + decoder_layers_per_block=(1, 1), + spatio_temporal_scaling=(True,), + decoder_spatio_temporal_scaling=(True,), + decoder_inject_noise=(False, False), + downsample_type=("spatial",), + upsample_residual=(False,), + upsample_factor=(1,), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + audio_vae = AutoencoderKLLTX2Audio( + base_channels=4, + output_channels=2, + ch_mult=(1,), + num_res_blocks=1, + attn_resolutions=None, + in_channels=2, + 
resolution=32, + latent_channels=2, + norm_type="pixel", + causality_axis="height", + dropout=0.0, + mid_block_add_attention=False, + sample_rate=16000, + mel_hop_length=160, + is_causal=True, + mel_bins=8, + ) + + torch.manual_seed(0) + vocoder = LTX2Vocoder( + in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins, + hidden_channels=32, + out_channels=2, + upsample_kernel_sizes=[4, 4], + upsample_factors=[2, 2], + resnet_kernel_sizes=[3], + resnet_dilations=[[1, 3, 5]], + leaky_relu_negative_slope=0.1, + output_sampling_rate=16000, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + components = { + "transformer": transformer, + "vae": vae, + "audio_vae": audio_vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "connectors": connectors, + "vocoder": vocoder, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "a robot dancing", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "num_frames": 5, + "frame_rate": 25.0, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = pipe(**inputs) + video = output.frames + audio = output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.4331, 0.6203, 0.3245, 0.7294, 0.4822, 0.5703, 0.2999, 0.7700, 0.4961, 0.4242, 0.4581, 0.4351, 0.1137, 0.4437, 0.6304, 0.3184 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0229, 0.0503, 0.1220, 0.1083, 0.1745, 0.1075, 0.1779, 0.0974, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2) diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py new file mode 100644 index 000000000000..9c58b4fc413d --- /dev/null +++ b/tests/pipelines/ltx2/test_ltx2_image2video.py @@ -0,0 +1,241 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2ImageToVideoPipeline, + LTX2VideoTransformer3DModel, +) +from diffusers.pipelines.ltx2 import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder + +from ...testing_utils import enable_full_determinism +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTX2ImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "audio_latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_attention_slicing = False + test_xformers_attention = False + supports_dduf = False + + base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3" + + def get_dummy_components(self): + tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id) + + torch.manual_seed(0) + transformer = LTX2VideoTransformer3DModel( + in_channels=4, + out_channels=4, + patch_size=1, + patch_size_t=1, + num_attention_heads=2, + attention_head_dim=8, + cross_attention_dim=16, + audio_in_channels=4, + audio_out_channels=4, + audio_num_attention_heads=2, + audio_attention_head_dim=4, + audio_cross_attention_dim=8, + num_layers=2, + qk_norm="rms_norm_across_heads", + caption_channels=text_encoder.config.text_config.hidden_size, + rope_double_precision=False, + rope_type="split", + ) + + torch.manual_seed(0) + connectors = LTX2TextConnectors( + caption_channels=text_encoder.config.text_config.hidden_size, + text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1, + video_connector_num_attention_heads=4, + video_connector_attention_head_dim=8, + video_connector_num_layers=1, + video_connector_num_learnable_registers=None, + audio_connector_num_attention_heads=4, + audio_connector_attention_head_dim=8, + audio_connector_num_layers=1, + audio_connector_num_learnable_registers=None, + connector_rope_base_seq_len=32, + rope_theta=10000.0, + rope_double_precision=False, + causal_temporal_positioning=False, + rope_type="split", + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTX2Video( + in_channels=3, + out_channels=3, + latent_channels=4, + block_out_channels=(8,), + decoder_block_out_channels=(8,), + layers_per_block=(1,), + decoder_layers_per_block=(1, 1), + spatio_temporal_scaling=(True,), + decoder_spatio_temporal_scaling=(True,), + decoder_inject_noise=(False, False), + downsample_type=("spatial",), + upsample_residual=(False,), + upsample_factor=(1,), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + audio_vae = 
AutoencoderKLLTX2Audio( + base_channels=4, + output_channels=2, + ch_mult=(1,), + num_res_blocks=1, + attn_resolutions=None, + in_channels=2, + resolution=32, + latent_channels=2, + norm_type="pixel", + causality_axis="height", + dropout=0.0, + mid_block_add_attention=False, + sample_rate=16000, + mel_hop_length=160, + is_causal=True, + mel_bins=8, + ) + + torch.manual_seed(0) + vocoder = LTX2Vocoder( + in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins, + hidden_channels=32, + out_channels=2, + upsample_kernel_sizes=[4, 4], + upsample_factors=[2, 2], + resnet_kernel_sizes=[3], + resnet_dilations=[[1, 3, 5]], + leaky_relu_negative_slope=0.1, + output_sampling_rate=16000, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + components = { + "transformer": transformer, + "vae": vae, + "audio_vae": audio_vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "connectors": connectors, + "vocoder": vocoder, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = torch.rand((1, 3, 32, 32), generator=generator, device=device) + + inputs = { + "image": image, + "prompt": "a robot dancing", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "num_frames": 5, + "frame_rate": 25.0, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = pipe(**inputs) + video = output.frames + audio = output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.3573, 0.8382, 0.3581, 0.6114, 0.3682, 0.7969, 0.2552, 0.6399, 0.3113, 0.1497, 0.3249, 0.5395, 0.3498, 0.4526, 0.4536, 0.4555 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0229, 0.0503, 0.1220, 0.1083, 0.1745, 0.1075, 0.1779, 0.0974, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2) From c039c87b99612c87573aed5805b8da3dc35ac1f7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 6 Jan 2026 08:09:59 +0530 Subject: [PATCH 58/86] up --- src/diffusers/utils/dummy_pt_objects.py | 15 ++++++++++ .../dummy_torch_and_transformers_objects.py | 30 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 619759cf36ee..d05486d56cbf 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1177,6 +1177,21 @@ def from_pretrained(cls, *args, 
**kwargs): requires_backends(cls, ["torch"]) +class LTX2VideoTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LTXVideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 6c28e87581b9..58b93428f3db 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1877,6 +1877,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTX2ImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LTX2Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXConditionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 550eca353041ef3b54170f16f35e0f1e68588a42 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 6 Jan 2026 09:14:38 +0530 Subject: [PATCH 59/86] use export util funcs. --- scripts/ltx2_test_full_pipeline.py | 109 +------------------------ scripts/ltx2_test_full_pipeline_i2v.py | 109 +------------------------ 2 files changed, 2 insertions(+), 216 deletions(-) diff --git a/scripts/ltx2_test_full_pipeline.py b/scripts/ltx2_test_full_pipeline.py index 16ea9f80404f..20eb65fde25c 100644 --- a/scripts/ltx2_test_full_pipeline.py +++ b/scripts/ltx2_test_full_pipeline.py @@ -1,117 +1,10 @@ import argparse import os -from fractions import Fraction -from typing import Optional -import av # Needs to be installed separately (`pip install av`) import torch from diffusers import LTX2Pipeline - - -# Video export functions copied from original LTX 2.0 code -def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: - """ - Prepare the audio stream for writing. 
- """ - audio_stream = container.add_stream("aac", rate=audio_sample_rate) - audio_stream.codec_context.sample_rate = audio_sample_rate - audio_stream.codec_context.layout = "stereo" - audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) - return audio_stream - - -def _resample_audio( - container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame -) -> None: - cc = audio_stream.codec_context - - # Use the encoder's format/layout/rate as the *target* - target_format = cc.format or "fltp" # AAC → usually fltp - target_layout = cc.layout or "stereo" - target_rate = cc.sample_rate or frame_in.sample_rate - - audio_resampler = av.audio.resampler.AudioResampler( - format=target_format, - layout=target_layout, - rate=target_rate, - ) - - audio_next_pts = 0 - for rframe in audio_resampler.resample(frame_in): - if rframe.pts is None: - rframe.pts = audio_next_pts - audio_next_pts += rframe.samples - rframe.sample_rate = frame_in.sample_rate - container.mux(audio_stream.encode(rframe)) - - # flush audio encoder - for packet in audio_stream.encode(): - container.mux(packet) - - -def _write_audio( - container: av.container.Container, - audio_stream: av.audio.AudioStream, - samples: torch.Tensor, - audio_sample_rate: int, -) -> None: - if samples.ndim == 1: - samples = samples[:, None] - - if samples.shape[1] != 2 and samples.shape[0] == 2: - samples = samples.T - - if samples.shape[1] != 2: - raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.") - - # Convert to int16 packed for ingestion; resampler converts to encoder fmt. - if samples.dtype != torch.int16: - samples = torch.clip(samples, -1.0, 1.0) - samples = (samples * 32767.0).to(torch.int16) - - frame_in = av.AudioFrame.from_ndarray( - samples.contiguous().reshape(1, -1).cpu().numpy(), - format="s16", - layout="stereo", - ) - frame_in.sample_rate = audio_sample_rate - - _resample_audio(container, audio_stream, frame_in) - - -def encode_video( - video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str -) -> None: - video_np = video.cpu().numpy() - - _, height, width, _ = video_np.shape - - container = av.open(output_path, mode="w") - stream = container.add_stream("libx264", rate=int(fps)) - stream.width = width - stream.height = height - stream.pix_fmt = "yuv420p" - - if audio is not None: - if audio_sample_rate is None: - raise ValueError("audio_sample_rate is required when audio is provided") - - audio_stream = _prepare_audio_stream(container, audio_sample_rate) - - for frame_array in video_np: - frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") - for packet in stream.encode(frame): - container.mux(packet) - - # Flush encoder - for packet in stream.encode(): - container.mux(packet) - - if audio is not None: - _write_audio(container, audio_stream, audio, audio_sample_rate) - - container.close() +from diffusers.pipelines.ltx2.export_utils import encode_video def parse_args(): diff --git a/scripts/ltx2_test_full_pipeline_i2v.py b/scripts/ltx2_test_full_pipeline_i2v.py index 8c39647eae88..f99d8c135e19 100644 --- a/scripts/ltx2_test_full_pipeline_i2v.py +++ b/scripts/ltx2_test_full_pipeline_i2v.py @@ -1,118 +1,11 @@ import argparse import os -from fractions import Fraction -from typing import Optional -import av # Needs to be installed separately (`pip install av`) import torch from PIL import Image from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline - - -# Video export functions copied from 
original LTX 2.0 code -def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: - """ - Prepare the audio stream for writing. - """ - audio_stream = container.add_stream("aac", rate=audio_sample_rate) - audio_stream.codec_context.sample_rate = audio_sample_rate - audio_stream.codec_context.layout = "stereo" - audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) - return audio_stream - - -def _resample_audio( - container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame -) -> None: - cc = audio_stream.codec_context - - # Use the encoder's format/layout/rate as the *target* - target_format = cc.format or "fltp" # AAC → usually fltp - target_layout = cc.layout or "stereo" - target_rate = cc.sample_rate or frame_in.sample_rate - - audio_resampler = av.audio.resampler.AudioResampler( - format=target_format, - layout=target_layout, - rate=target_rate, - ) - - audio_next_pts = 0 - for rframe in audio_resampler.resample(frame_in): - if rframe.pts is None: - rframe.pts = audio_next_pts - audio_next_pts += rframe.samples - rframe.sample_rate = frame_in.sample_rate - container.mux(audio_stream.encode(rframe)) - - # flush audio encoder - for packet in audio_stream.encode(): - container.mux(packet) - - -def _write_audio( - container: av.container.Container, - audio_stream: av.audio.AudioStream, - samples: torch.Tensor, - audio_sample_rate: int, -) -> None: - if samples.ndim == 1: - samples = samples[:, None] - - if samples.shape[1] != 2 and samples.shape[0] == 2: - samples = samples.T - - if samples.shape[1] != 2: - raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.") - - # Convert to int16 packed for ingestion; resampler converts to encoder fmt. 
- if samples.dtype != torch.int16: - samples = torch.clip(samples, -1.0, 1.0) - samples = (samples * 32767.0).to(torch.int16) - - frame_in = av.AudioFrame.from_ndarray( - samples.contiguous().reshape(1, -1).cpu().numpy(), - format="s16", - layout="stereo", - ) - frame_in.sample_rate = audio_sample_rate - - _resample_audio(container, audio_stream, frame_in) - - -def encode_video( - video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str -) -> None: - video_np = video.cpu().numpy() - - _, height, width, _ = video_np.shape - - container = av.open(output_path, mode="w") - stream = container.add_stream("libx264", rate=int(fps)) - stream.width = width - stream.height = height - stream.pix_fmt = "yuv420p" - - if audio is not None: - if audio_sample_rate is None: - raise ValueError("audio_sample_rate is required when audio is provided") - - audio_stream = _prepare_audio_stream(container, audio_sample_rate) - - for frame_array in video_np: - frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") - for packet in stream.encode(frame): - container.mux(packet) - - # Flush encoder - for packet in stream.encode(): - container.mux(packet) - - if audio is not None: - _write_audio(container, audio_stream, audio, audio_sample_rate) - - container.close() +from diffusers.pipelines.ltx2.export_utils import encode_video def parse_args(): From ef199118e2c8fc0852369cf2394549c1fadaf21e Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 6 Jan 2026 06:35:51 +0100 Subject: [PATCH 60/86] Point original checkpoint to LTX 2.0 official checkpoint --- scripts/convert_ltx2_to_diffusers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index fa461979785e..4ec654d9d7ba 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -635,7 +635,7 @@ def get_args(): parser.add_argument( "--original_state_dict_repo_id", - default="diffusers-internal-dev/new-ltx-model", + default="Lightricks/LTX-2", type=str, help="HF Hub repo id with LTX 2.0 checkpoint", ) @@ -655,7 +655,7 @@ def get_args(): parser.add_argument( "--combined_filename", - default="ltx-av-step-1932500-interleaved-new-vae.safetensors", + default="ltx-2-19b-dev.safetensors", type=str, help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)", ) From ace2ee93fbe47d6d095e90e4e90a74bcf72a9823 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 6 Jan 2026 06:40:42 +0100 Subject: [PATCH 61/86] Allow the I2V pipeline to accept image URLs --- scripts/ltx2_test_full_pipeline_i2v.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/ltx2_test_full_pipeline_i2v.py b/scripts/ltx2_test_full_pipeline_i2v.py index f99d8c135e19..bb2d05bd02e5 100644 --- a/scripts/ltx2_test_full_pipeline_i2v.py +++ b/scripts/ltx2_test_full_pipeline_i2v.py @@ -6,6 +6,7 @@ from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline from diffusers.pipelines.ltx2.export_utils import encode_video +from diffusers.utils import load_image def parse_args(): @@ -67,8 +68,10 @@ def main(args): else: pipeline.to(device=args.device) + image = load_image(args.image_path) + video, audio = pipeline( - image=Image.open(args.image_path), + image=image, prompt=args.prompt, negative_prompt=args.negative_prompt, height=args.height, From dd81242ebadb2218410ed5f22924ea214240f1de Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 6 Jan 2026 06:42:24 +0100 Subject: [PATCH 62/86] 
make style and make quality --- scripts/ltx2_test_full_pipeline_i2v.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/ltx2_test_full_pipeline_i2v.py b/scripts/ltx2_test_full_pipeline_i2v.py index bb2d05bd02e5..cbe61eecdf7c 100644 --- a/scripts/ltx2_test_full_pipeline_i2v.py +++ b/scripts/ltx2_test_full_pipeline_i2v.py @@ -2,7 +2,6 @@ import os import torch -from PIL import Image from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline from diffusers.pipelines.ltx2.export_utils import encode_video From 57ead0b5e5218a5bca29f7e65be94ea45424a809 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 6 Jan 2026 20:48:16 +0530 Subject: [PATCH 63/86] remove function map. --- .../models/transformers/transformer_ltx2.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 2182a59cd093..8dcd8a00509e 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -93,9 +93,6 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten return out -ROTARY_FN_MAP = {"interleaved": apply_interleaved_rotary_emb, "split": apply_split_rotary_emb} - - @dataclass class AudioVisualModelOutput(BaseOutput): r""" @@ -198,10 +195,14 @@ def __call__( key = attn.norm_k(key) if query_rotary_emb is not None: - query = ROTARY_FN_MAP[attn.rope_type](query, query_rotary_emb) - key = ROTARY_FN_MAP[attn.rope_type]( - key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb - ) + if attn.rope_type == "interleaved": + query = apply_interleaved_rotary_emb(query, query_rotary_emb) + key = apply_interleaved_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + elif attn.rope_type == "split": + query = apply_split_rotary_emb(query, query_rotary_emb) + key = apply_split_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) query = query.unflatten(2, (attn.heads, -1)) key = key.unflatten(2, (attn.heads, -1)) From c39f1b87a4b98109f9381ae5691b3b16b5940cc5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 6 Jan 2026 20:52:49 +0530 Subject: [PATCH 64/86] remove args. --- .../models/transformers/transformer_ltx2.py | 203 +++++++++--------- 1 file changed, 97 insertions(+), 106 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 8dcd8a00509e..ab369ebfe472 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -459,49 +459,43 @@ def forward( audio_encoder_attention_mask: Optional[torch.Tensor] = None, a2v_cross_attention_mask: Optional[torch.Tensor] = None, v2a_cross_attention_mask: Optional[torch.Tensor] = None, - use_video_self_attn: bool = True, - use_audio_self_attn: bool = True, - use_a2v_cross_attn: bool = True, - use_v2a_cross_attn: bool = True, ) -> torch.Tensor: batch_size = hidden_states.size(0) # 1. 
Video and Audio Self-Attention - if use_video_self_attn: - norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = self.norm1(hidden_states) - num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( - batch_size, temb.size(1), num_ada_params, -1 - ) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( + batch_size, temb.size(1), num_ada_params, -1 + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - attn_hidden_states = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=None, - query_rotary_emb=video_rotary_emb, - ) - hidden_states = hidden_states + attn_hidden_states * gate_msa + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=video_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa - if use_audio_self_attn: - norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) - num_audio_ada_params = self.audio_scale_shift_table.shape[0] - audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( - batch_size, temb_audio.size(1), num_audio_ada_params, -1 - ) - audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( - audio_ada_values.unbind(dim=2) - ) - norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa + num_audio_ada_params = self.audio_scale_shift_table.shape[0] + audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( + batch_size, temb_audio.size(1), num_audio_ada_params, -1 + ) + audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( + audio_ada_values.unbind(dim=2) + ) + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa - attn_audio_hidden_states = self.audio_attn1( - hidden_states=norm_audio_hidden_states, - encoder_hidden_states=None, - query_rotary_emb=audio_rotary_emb, - ) - audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa + attn_audio_hidden_states = self.audio_attn1( + hidden_states=norm_audio_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=audio_rotary_emb, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa # 2. Video and Audio Cross-Attention with the text embeddings norm_hidden_states = self.norm2(hidden_states) @@ -523,80 +517,77 @@ def forward( audio_hidden_states = audio_hidden_states + attn_audio_hidden_states # 3. 
Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention - if use_a2v_cross_attn or use_v2a_cross_attn: - norm_hidden_states = self.audio_to_video_norm(hidden_states) - norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) - - # Combine global and per-layer cross attention modulation parameters - # Video - video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] - video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] - - video_ca_scale_shift_table = ( - video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) - + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) - ).unbind(dim=2) - video_ca_gate = ( - video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) - + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) - ).unbind(dim=2) - - video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table - a2v_gate = video_ca_gate[0].squeeze(2) - - # Audio - audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] - audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] - - audio_ca_scale_shift_table = ( - audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype) - + temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) - ).unbind(dim=2) - audio_ca_gate = ( - audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) - + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) - ).unbind(dim=2) - - audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table - v2a_gate = audio_ca_gate[0].squeeze(2) - - if use_a2v_cross_attn: - # Audio-to-Video Cross Attention: Q: Video; K,V: Audio - mod_norm_hidden_states = norm_hidden_states * ( - 1 + video_a2v_ca_scale.squeeze(2) - ) + video_a2v_ca_shift.squeeze(2) - mod_norm_audio_hidden_states = norm_audio_hidden_states * ( - 1 + audio_a2v_ca_scale.squeeze(2) - ) + audio_a2v_ca_shift.squeeze(2) - - a2v_attn_hidden_states = self.audio_to_video_attn( - mod_norm_hidden_states, - encoder_hidden_states=mod_norm_audio_hidden_states, - query_rotary_emb=ca_video_rotary_emb, - key_rotary_emb=ca_audio_rotary_emb, - attention_mask=a2v_cross_attention_mask, - ) + norm_hidden_states = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) + + # Combine global and per-layer cross attention modulation parameters + # Video + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] + + video_ca_scale_shift_table = ( + video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) + + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + video_ca_gate = ( + video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) + + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) + ).unbind(dim=2) + + video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table + a2v_gate = video_ca_gate[0].squeeze(2) + + # Audio + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + + audio_ca_scale_shift_table = ( + audio_per_layer_ca_scale_shift[:, :, 
...].to(temb_ca_audio_scale_shift.dtype) + + temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + audio_ca_gate = ( + audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) + + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) + ).unbind(dim=2) + + audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table + v2a_gate = audio_ca_gate[0].squeeze(2) + + # Audio-to-Video Cross Attention: Q: Video; K,V: Audio + mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze( + 2 + ) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_a2v_ca_scale.squeeze(2) + ) + audio_a2v_ca_shift.squeeze(2) + + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + query_rotary_emb=ca_video_rotary_emb, + key_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) - hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states - - if use_v2a_cross_attn: - # Video-to-Audio Cross Attention: Q: Audio; K,V: Video - mod_norm_hidden_states = norm_hidden_states * ( - 1 + video_v2a_ca_scale.squeeze(2) - ) + video_v2a_ca_shift.squeeze(2) - mod_norm_audio_hidden_states = norm_audio_hidden_states * ( - 1 + audio_v2a_ca_scale.squeeze(2) - ) + audio_v2a_ca_shift.squeeze(2) - - v2a_attn_hidden_states = self.video_to_audio_attn( - mod_norm_audio_hidden_states, - encoder_hidden_states=mod_norm_hidden_states, - query_rotary_emb=ca_audio_rotary_emb, - key_rotary_emb=ca_video_rotary_emb, - attention_mask=v2a_cross_attention_mask, - ) + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + # Video-to-Audio Cross Attention: Q: Audio; K,V: Video + mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze( + 2 + ) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_v2a_ca_scale.squeeze(2) + ) + audio_v2a_ca_shift.squeeze(2) + + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + encoder_hidden_states=mod_norm_hidden_states, + query_rotary_emb=ca_audio_rotary_emb, + key_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) - audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states # 4. Feedforward norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp From bdcf23ec17af159d9b5f2e9729118af492343467 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 6 Jan 2026 21:02:18 +0530 Subject: [PATCH 65/86] update docs. 
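
The docstring examples now target the `Lightricks/LTX-2` checkpoint and the
`encode_video` helper from `diffusers.pipelines.ltx2.export_utils`. For
reference, below is a minimal sketch of the documented image-to-video flow.
It assumes the final hub layout; the tuple unpacking `video, audio = pipe(...)`
is presumably what the image-to-video example intends with `return_dict=False`,
since it later passes `audio[0]` to `encode_video`:

```python
import torch

from diffusers import LTX2ImageToVideoPipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image

pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

image = load_image(
    "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
)
frame_rate = 24.0

# With return_dict=False the pipeline returns (video, audio); for output_type="np"
# the video is a float array in [0, 1] and must be converted to uint8 frames.
video, audio = pipe(
    image=image,
    prompt="A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background.",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=768,
    height=512,
    num_frames=121,
    frame_rate=frame_rate,
    output_type="np",
    return_dict=False,
)
video = torch.from_numpy((video * 255).round().astype("uint8"))

# Mux the first video/audio pair into an MP4; the vocoder config carries the
# sampling rate used for the audio track.
encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="video.mp4",
)
```
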
--- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 34 +++++++++++++------ .../ltx2/pipeline_ltx2_image2video.py | 31 ++++++++++++----- 2 files changed, 45 insertions(+), 20 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 7cbcca67d2c6..103e324e1193 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -47,24 +47,36 @@ Examples: ```py >>> import torch - >>> from diffusers import LTXPipeline - >>> from diffusers.utils import export_to_video + >>> from diffusers import LTX2Pipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video - >>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") + >>> pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" - >>> video = pipe( + >>> frame_rate = 24.0 + >>> video, audio = pipe( ... prompt=prompt, ... negative_prompt=negative_prompt, - ... width=704, - ... height=480, - ... num_frames=161, - ... num_inference_steps=50, - ... ).frames[0] - >>> export_to_video(video, "output.mp4", fps=24) + ... width=768, + ... height=512, + ... frame_rate=frame_rate, + ... num_frames=121, + ... output_type="np", + ... return_dict=False, + ... ) + >>> video = (video * 255).round().astype("uint8") + >>> video = torch.from_numpy(video) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) ``` """ diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 0a707806ce1b..df0faa2b07f1 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -48,11 +48,12 @@ Examples: ```py >>> import torch - >>> from diffusers import LTX2ImageToVideoPipeline - >>> from diffusers.utils import export_to_video, load_image + >>> from diffusers import LTX2Pipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.utils import load_image - >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") + >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() >>> image = load_image( ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" @@ -60,16 +61,28 @@ >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background." >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + >>> frame_rate = 24.0 >>> video = pipe( ... image=image, ... prompt=prompt, ... 
negative_prompt=negative_prompt, - ... width=704, - ... height=480, + ... width=768, + ... height=512, ... num_frames=121, - ... num_inference_steps=40, - ... ).frames[0] - >>> export_to_video(video, "output.mp4", fps=24) + ... frame_rate=frame_rate, + ... output_type="np", + ... return_dict=False, + ... ) + >>> video = (video * 255).round().astype("uint8") + >>> video = torch.from_numpy(video) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) ``` """ From 61e0fb4bd85717f50cf6f4bf806860816609f799 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 6 Jan 2026 21:15:47 +0530 Subject: [PATCH 66/86] update doc entries. --- docs/source/en/_toctree.yml | 8 ++++ .../api/models/autoencoderkl_audio_ltx_2.md | 28 ++++++++++++++ .../en/api/models/autoencoderkl_ltx_2.md | 29 +++++++++++++++ .../en/api/models/ltx2_video_transformer3d.md | 26 +++++++++++++ docs/source/en/api/pipelines/ltx2.md | 37 +++++++++++++++++++ .../autoencoders/autoencoder_kl_ltx2.py | 2 +- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 6 +-- .../ltx2/pipeline_ltx2_image2video.py | 6 +-- 8 files changed, 135 insertions(+), 7 deletions(-) create mode 100644 docs/source/en/api/models/autoencoderkl_audio_ltx_2.md create mode 100644 docs/source/en/api/models/autoencoderkl_ltx_2.md create mode 100644 docs/source/en/api/models/ltx2_video_transformer3d.md create mode 100644 docs/source/en/api/pipelines/ltx2.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f0cb0164436e..ed5f01a0250d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -367,6 +367,8 @@ title: LatteTransformer3DModel - local: api/models/longcat_image_transformer2d title: LongCatImageTransformer2DModel + - local: api/models/ltx2_video_transformer3d + title: LTX2VideoTransformer3DModel - local: api/models/ltx_video_transformer3d title: LTXVideoTransformer3DModel - local: api/models/lumina2_transformer2d @@ -443,6 +445,10 @@ title: AutoencoderKLHunyuanVideo - local: api/models/autoencoder_kl_hunyuan_video15 title: AutoencoderKLHunyuanVideo15 + - local: api/models/autoencoderkl_audio_ltx_2 + title: AutoencoderKLLTX2Audio + - local: api/models/autoencoderkl_ltx_2 + title: AutoencoderKLLTX2Video - local: api/models/autoencoderkl_ltx_video title: AutoencoderKLLTXVideo - local: api/models/autoencoderkl_magvit @@ -678,6 +684,8 @@ title: Kandinsky 5.0 Video - local: api/pipelines/latte title: Latte + - local: api/pipelines/ltx2 + title: LTX-2 - local: api/pipelines/ltx_video title: LTXVideo - local: api/pipelines/mochi diff --git a/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md b/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md new file mode 100644 index 000000000000..36d1834e58c1 --- /dev/null +++ b/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md @@ -0,0 +1,28 @@ + + +# AutoencoderKLLTX2Audio + +The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. This is for encoding and decoding audio latent representations. + +The model can be loaded with the following code snippet. 
+ +```python +from diffusers import AutoencoderKLLTX2Audio + +vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda") +``` + +## AutoencoderKLLTX2Audio + +[[autodoc]] AutoencoderKLLTX2Audio + - decode + - all \ No newline at end of file diff --git a/docs/source/en/api/models/autoencoderkl_ltx_2.md b/docs/source/en/api/models/autoencoderkl_ltx_2.md new file mode 100644 index 000000000000..1dbf516c017a --- /dev/null +++ b/docs/source/en/api/models/autoencoderkl_ltx_2.md @@ -0,0 +1,29 @@ + + +# AutoencoderKLLTX2Video + +The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLLTX2Video + +vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda") +``` + +## AutoencoderKLLTX2Video + +[[autodoc]] AutoencoderKLLTX2Video + - decode + - encode + - all diff --git a/docs/source/en/api/models/ltx2_video_transformer3d.md b/docs/source/en/api/models/ltx2_video_transformer3d.md new file mode 100644 index 000000000000..9faab8695468 --- /dev/null +++ b/docs/source/en/api/models/ltx2_video_transformer3d.md @@ -0,0 +1,26 @@ + + +# LTX2VideoTransformer3DModel + +A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. + +The model can be loaded with the following code snippet. + +```python +from diffusers import LTX2VideoTransformer3DModel + +transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") +``` + +## LTX2VideoTransformer3DModel + +[[autodoc]] LTX2VideoTransformer3DModel diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md new file mode 100644 index 000000000000..c71def1ab716 --- /dev/null +++ b/docs/source/en/api/pipelines/ltx2.md @@ -0,0 +1,37 @@ + + +# LTX-2 + +LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution. + +You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization. + +The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2). + +## LTX2Pipeline + +[[autodoc]] LTX2Pipeline + - all + - __call__ + +## LTX2ImageToVideoPipeline + +[[autodoc]] LTX2ImageToVideoPipeline + - all + - __call__ + +## LTX2PipelineOutput + +[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index 2d55f166c6fd..01dd55a938b6 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -1004,7 +1004,7 @@ def forward( class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in - [LTX](https://huggingface.co/Lightricks/LTX-Video). + [LTX-2](https://huggingface.co/Lightricks/LTX-2). This model inherits from [`ModelMixin`]. 
Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 103e324e1193..e95b8d5c0b2e 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -833,7 +833,7 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ltx.LTX2PipelineOutput`] instead of a plain tuple. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -853,8 +853,8 @@ def __call__( Examples: Returns: - [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images. """ diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index df0faa2b07f1..5a4b27280958 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -896,7 +896,7 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ltx.LTX2PipelineOutput`] instead of a plain tuple. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -916,8 +916,8 @@ def __call__( Examples: Returns: - [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images. 
""" From 8c5ab1fd6d6b014682653353590d9c0a9db1cbd3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 6 Jan 2026 21:31:29 +0530 Subject: [PATCH 67/86] disable ltx2_consistency test --- .../test_models_transformer_ltx2.py | 212 +++++++++--------- 1 file changed, 106 insertions(+), 106 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index 8a6b50b55eea..af9ef0623891 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -17,7 +17,7 @@ import torch -from diffusers import LTX2VideoTransformer3DModel, attention_backend +from diffusers import LTX2VideoTransformer3DModel from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -108,111 +108,111 @@ def test_gradient_checkpointing_is_applied(self): expected_set = {"LTX2VideoTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - def test_ltx2_consistency(self, seed=0, dtype=torch.float32): - torch.manual_seed(seed) - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - - # Calculate dummy inputs in a custom manner to ensure compatibility with original code - batch_size = 2 - num_frames = 9 - latent_frames = 2 - text_embedding_dim = 16 - text_seq_len = 16 - fps = 25.0 - sampling_rate = 16000.0 - hop_length = 160.0 - - sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000 - timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device) - - num_channels = 4 - latent_height = 4 - latent_width = 4 - hidden_states = torch.randn( - (batch_size, num_channels, latent_frames, latent_height, latent_width), - generator=torch.manual_seed(seed), - dtype=dtype, - device="cpu", - ) - # Patchify video latents (with patch_size (1, 1, 1)) - hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1) - hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) - encoder_hidden_states = torch.randn( - (batch_size, text_seq_len, text_embedding_dim), - generator=torch.manual_seed(seed), - dtype=dtype, - device="cpu", - ) - - audio_num_channels = 2 - num_mel_bins = 2 - latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps)) - audio_hidden_states = torch.randn( - (batch_size, audio_num_channels, latent_length, num_mel_bins), - generator=torch.manual_seed(seed), - dtype=dtype, - device="cpu", - ) - # Patchify audio latents - audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3) - audio_encoder_hidden_states = torch.randn( - (batch_size, text_seq_len, text_embedding_dim), - generator=torch.manual_seed(seed), - dtype=dtype, - device="cpu", - ) - - inputs_dict = { - "hidden_states": hidden_states.to(device=torch_device), - "audio_hidden_states": audio_hidden_states.to(device=torch_device), - "encoder_hidden_states": encoder_hidden_states.to(device=torch_device), - "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device), - "timestep": timestep, - "num_frames": latent_frames, - "height": latent_height, - "width": latent_width, - "audio_num_frames": num_frames, - "fps": 25.0, - } - - model = self.model_class.from_pretrained( - "diffusers-internal-dev/dummy-ltx2", - subfolder="transformer", - device_map="cpu", - ) - # torch.manual_seed(seed) - # model = 
self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with attention_backend("native"): - with torch.no_grad(): - output = model(**inputs_dict) - - video_output, audio_output = output.to_tuple() - - self.assertIsNotNone(video_output) - self.assertIsNotNone(audio_output) - - # input & output have to have the same shape - video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels) - self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match") - audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins) - self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match") - - # Check against expected slice - # fmt: off - video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676]) - audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692]) - # fmt: on - - video_output_flat = video_output.cpu().flatten().float() - video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]]) - self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4)) - - audio_output_flat = audio_output.cpu().flatten().float() - audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]]) - self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4)) + # def test_ltx2_consistency(self, seed=0, dtype=torch.float32): + # torch.manual_seed(seed) + # init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + # # Calculate dummy inputs in a custom manner to ensure compatibility with original code + # batch_size = 2 + # num_frames = 9 + # latent_frames = 2 + # text_embedding_dim = 16 + # text_seq_len = 16 + # fps = 25.0 + # sampling_rate = 16000.0 + # hop_length = 160.0 + + # sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000 + # timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device) + + # num_channels = 4 + # latent_height = 4 + # latent_width = 4 + # hidden_states = torch.randn( + # (batch_size, num_channels, latent_frames, latent_height, latent_width), + # generator=torch.manual_seed(seed), + # dtype=dtype, + # device="cpu", + # ) + # # Patchify video latents (with patch_size (1, 1, 1)) + # hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1) + # hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + # encoder_hidden_states = torch.randn( + # (batch_size, text_seq_len, text_embedding_dim), + # generator=torch.manual_seed(seed), + # dtype=dtype, + # device="cpu", + # ) + + # audio_num_channels = 2 + # num_mel_bins = 2 + # latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps)) + # audio_hidden_states = torch.randn( + # (batch_size, audio_num_channels, latent_length, num_mel_bins), + # generator=torch.manual_seed(seed), + # dtype=dtype, + # device="cpu", + # ) + # # Patchify audio latents + # audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3) + # audio_encoder_hidden_states = torch.randn( + # (batch_size, text_seq_len, text_embedding_dim), + # generator=torch.manual_seed(seed), + # dtype=dtype, + # device="cpu", + # ) + + # 
inputs_dict = { + # "hidden_states": hidden_states.to(device=torch_device), + # "audio_hidden_states": audio_hidden_states.to(device=torch_device), + # "encoder_hidden_states": encoder_hidden_states.to(device=torch_device), + # "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device), + # "timestep": timestep, + # "num_frames": latent_frames, + # "height": latent_height, + # "width": latent_width, + # "audio_num_frames": num_frames, + # "fps": 25.0, + # } + + # model = self.model_class.from_pretrained( + # "diffusers-internal-dev/dummy-ltx2", + # subfolder="transformer", + # device_map="cpu", + # ) + # # torch.manual_seed(seed) + # # model = self.model_class(**init_dict) + # model.to(torch_device) + # model.eval() + + # with attention_backend("native"): + # with torch.no_grad(): + # output = model(**inputs_dict) + + # video_output, audio_output = output.to_tuple() + + # self.assertIsNotNone(video_output) + # self.assertIsNotNone(audio_output) + + # # input & output have to have the same shape + # video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels) + # self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match") + # audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins) + # self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match") + + # # Check against expected slice + # # fmt: off + # video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676]) + # audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692]) + # # fmt: on + + # video_output_flat = video_output.cpu().flatten().float() + # video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]]) + # self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4)) + + # audio_output_flat = audio_output.cpu().flatten().float() + # audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]]) + # self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4)) class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): From 5e0cf2b2f0f0aff1d8c76aab787be2f88095da7c Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 6 Jan 2026 23:32:59 +0100 Subject: [PATCH 68/86] Simplify LTX 2 RoPE forward by removing coords is None logic --- .../models/transformers/transformer_ltx2.py | 47 ++++--------------- 1 file changed, 9 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index ab369ebfe472..413ceb24fab7 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -795,51 +795,22 @@ def prepare_coords(self, *args, **kwargs): def forward( self, - coords: Optional[torch.Tensor] = None, - batch_size: Optional[int] = None, - num_frames: Optional[int] = None, - height: Optional[int] = None, - width: Optional[int] = None, - fps: float = 25.0, - shift: int = 0, + coords: torch.Tensor, device: Optional[Union[str, torch.device]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - if coords is not None: - device = device or coords.device - batch_size = batch_size or 
coords.size(0) - else: - device = device or "cpu" - batch_size = batch_size or 1 - - # 1. Calculate the coordinate grid with respect to data space for the given modality (video, audio). - if coords is None and self.modality == "video": - coords = self.prepare_video_coords( - batch_size, - num_frames, - height, - width, - device=device, - fps=fps, - ) - elif coords is None and self.modality == "audio": - coords = self.prepare_audio_coords( - batch_size, - num_frames, - device=device, - shift=shift, - fps=fps, - ) + device = device or coords.device + # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn) num_pos_dims = coords.shape[1] - # 2. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch + # 1. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch # position index if coords.ndim == 4: coords_start, coords_end = coords.chunk(2, dim=-1) coords = (coords_start + coords_end) / 2.0 coords = coords.squeeze(-1) # [B, num_pos_dims, num_patches] - # 3. Get coordinates as a fraction of the base data shape + # 2. Get coordinates as a fraction of the base data shape if self.modality == "video": max_positions = (self.base_num_frames, self.base_height, self.base_width) elif self.modality == "audio": @@ -849,7 +820,7 @@ def forward( # Number of spatiotemporal dimensions (3 for video, 1 for audio and cross attn) times 2 for cos, sin num_rope_elems = num_pos_dims * 2 - # 4. Create a 1D grid of frequencies for RoPE + # 3. Create a 1D grid of frequencies for RoPE freqs_dtype = torch.float64 if self.double_precision else torch.float32 pow_indices = torch.pow( self.theta, @@ -857,12 +828,12 @@ def forward( ) freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) - # 5. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape + # 4. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape # (self.dim // num_elems,) freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, num_patches, num_pos_dims, self.dim // num_elems] freqs = freqs.transpose(-1, -2).flatten(2) # [B, num_patches, self.dim // 2] - # 6. Get real, interleaved (cos, sin) frequencies, padded to self.dim + # 5. Get real, interleaved (cos, sin) frequencies, padded to self.dim # TODO: consider implementing this as a utility and reuse in `connectors.py`. 
# src/diffusers/pipelines/ltx2/connectors.py if self.rope_type == "interleaved": @@ -1212,7 +1183,7 @@ def forward( batch_size, audio_num_frames, audio_hidden_states.device, fps=fps ) - video_rotary_emb = self.rope(video_coords, fps=fps, device=hidden_states.device) + video_rotary_emb = self.rope(video_coords, device=hidden_states.device) audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) From d01a242cdbc30ab761a5e8feac244f769e7f9d81 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 6 Jan 2026 23:54:23 +0100 Subject: [PATCH 69/86] make style and make quality --- src/diffusers/models/transformers/transformer_ltx2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 413ceb24fab7..bf3f3d13c59b 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -794,9 +794,7 @@ def prepare_coords(self, *args, **kwargs): return self.prepare_audio_coords(*args, **kwargs) def forward( - self, - coords: torch.Tensor, - device: Optional[Union[str, torch.device]] = None, + self, coords: torch.Tensor, device: Optional[Union[str, torch.device]] = None ) -> Tuple[torch.Tensor, torch.Tensor]: device = device or coords.device From 79cf6d7ba451e7c84540b57850ec3f99c2fce9ef Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 7 Jan 2026 04:16:03 +0100 Subject: [PATCH 70/86] Support LTX 2.0 audio VAE encoder --- scripts/convert_ltx2_to_diffusers.py | 6 +- .../autoencoders/autoencoder_kl_ltx2_audio.py | 269 ++++++++++++++++-- .../test_models_autoencoder_kl_ltx2_audio.py | 88 ++++++ 3 files changed, 340 insertions(+), 23 deletions(-) create mode 100644 tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 4ec654d9d7ba..eb0b010075b4 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -148,10 +148,7 @@ def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str "per_channel_statistics.mean-of-stds": remove_keys_inplace, } -LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = { - "encoder": remove_keys_inplace, - "per_channel_statistics": convert_ltx2_audio_vae_per_channel_statistics, -} +LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {} LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} @@ -499,6 +496,7 @@ def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "mel_hop_length": 160, "is_causal": True, "mel_bins": 64, + "double_z": True, }, } rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 091d55645a5d..dc09f44d82c3 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Set, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn @@ -21,8 +21,9 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import AutoencoderMixin, DecoderOutput +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution LATENT_DOWNSAMPLE_FACTOR = 4 @@ -219,6 +220,40 @@ def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch return x + h +class LTX2AudioDownsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.with_conv: + # Padding tuple is in the order: (left, right, top, bottom). + if self.causality_axis == "none": + pad = (0, 1, 0, 1) + elif self.causality_axis == "width": + pad = (2, 0, 0, 1) + elif self.causality_axis == "height": + pad = (0, 1, 2, 0) + elif self.causality_axis == "width-compatibility": + pad = (1, 0, 0, 1) + else: + raise ValueError( + f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`," + f" and `width-compatibility`." + ) + + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # with_conv=False implies that causality_axis is "none" + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + class LTX2AudioUpsample(nn.Module): def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None: super().__init__() @@ -282,6 +317,156 @@ def patch_size(self) -> Tuple[int, int, int]: return self._patch_size +class LTX2AudioEncoder(nn.Module): + def __init__( + self, + base_channels: int = 128, + output_channels: int = 1, + num_res_blocks: int = 2, + attn_resolutions: Optional[Tuple[int, ...]] = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + ch_mult: Tuple[int, ...] 
= (1, 2, 4), + norm_type: str = "group", + causality_axis: Optional[str] = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = 64, + double_z: bool = True, + ): + super().__init__() + + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + + base_block_channels = base_channels + base_resolution = resolution + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + if self.causality_axis is not None: + self.conv_in = LTX2AudioCausalConv2d( + in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1) + + self.down = nn.ModuleList() + block_in = base_block_channels + curr_res = self.resolution + + for level in range(self.num_resolutions): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) + + if level != self.num_resolutions - 1: + stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis) + curr_res = curr_res // 2 + + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + + final_block_channels = block_in + z_channels = 2 * latent_channels if double_z else latent_channels + if self.norm_type == "group": + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True) + elif self.norm_type == "pixel": + self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {self.norm_type}") + self.non_linearity = nn.SiLU() + + if self.causality_axis is not None: + self.conv_out = LTX2AudioCausalConv2d( + final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, 
stride=1, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # hidden_states expected shape: (batch_size, channels, time, num_mel_bins) + hidden_states = self.conv_in(hidden_states) + + for level in range(self.num_resolutions): + stage = self.down[level] + for block_idx, block in enumerate(stage.block): + hidden_states = block(hidden_states, temb=None) + if stage.attn: + hidden_states = stage.attn[block_idx](hidden_states) + + if level != self.num_resolutions - 1 and hasattr(stage, "downsample"): + hidden_states = stage.downsample(hidden_states) + + hidden_states = self.mid.block_1(hidden_states, temb=None) + hidden_states = self.mid.attn_1(hidden_states) + hidden_states = self.mid.block_2(hidden_states, temb=None) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + class LTX2AudioDecoder(nn.Module): """ Symmetric decoder that reconstructs audio spectrograms from latent features. @@ -292,22 +477,22 @@ class LTX2AudioDecoder(nn.Module): def __init__( self, - base_channels: int, - output_channels: int, - num_res_blocks: int, - attn_resolutions: Set[int], - in_channels: int, - resolution: int, - latent_channels: int, - ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + base_channels: int = 128, + output_channels: int = 1, + num_res_blocks: int = 2, + attn_resolutions: Optional[Tuple[int, ...]] = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + ch_mult: Tuple[int, ...] = (1, 2, 4), norm_type: str = "group", causality_axis: Optional[str] = "width", dropout: float = 0.0, - mid_block_add_attention: bool = True, + mid_block_add_attention: bool = False, sample_rate: int = 16000, mel_hop_length: int = 160, is_causal: bool = True, - mel_bins: Optional[int] = None, + mel_bins: Optional[int] = 64, ) -> None: super().__init__() @@ -493,9 +678,9 @@ def __init__( self, base_channels: int = 128, output_channels: int = 2, - ch_mult: Tuple[int] = (1, 2, 4), + ch_mult: Tuple[int, ...] 
= (1, 2, 4), num_res_blocks: int = 2, - attn_resolutions: Optional[Tuple[int]] = None, + attn_resolutions: Optional[Tuple[int, ...]] = None, in_channels: int = 2, resolution: int = 256, latent_channels: int = 8, @@ -507,6 +692,7 @@ def __init__( mel_hop_length: int = 160, is_causal: bool = True, mel_bins: Optional[int] = 64, + double_z: bool = True, ) -> None: super().__init__() @@ -516,6 +702,26 @@ def __init__( attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions + self.encoder = LTX2AudioEncoder( + base_channels=base_channels, + output_channels=output_channels, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolution_set, + in_channels=in_channels, + resolution=resolution, + latent_channels=latent_channels, + norm_type=norm_type, + causality_axis=causality_axis, + dropout=dropout, + mid_block_add_attention=mid_block_add_attention, + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + double_z=double_z, + ) + self.decoder = LTX2AudioDecoder( base_channels=base_channels, output_channels=output_channels, @@ -548,9 +754,21 @@ def __init__( self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR self.use_slicing = False + def _encode(self, x: torch.Tensor) -> torch.Tensor: + return self.encoder(x) + @apply_forward_hook def encode(self, x: torch.Tensor, return_dict: bool = True): - raise NotImplementedError("AutoencoderKLLTX2Audio does not implement encoding.") + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.Tensor) -> torch.Tensor: return self.decoder(z) @@ -568,7 +786,20 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp return DecoderOutput(sample=decoded) - def forward(self, *args, **kwargs): - raise NotImplementedError( - "This model doesn't have an encoder yet so we don't implement its `forward()`. Please use `decode()`." - ) + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + posterior = self.encode(sample).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + print(f"z shape: {z.shape}") + dec = self.decode(z) + if not return_dict: + return (dec.sample,) + return dec diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py b/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py new file mode 100644 index 000000000000..3c10330e20aa --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from diffusers import AutoencoderKLLTX2Audio + +from ...testing_utils import ( + floats_tensor, + torch_device, +) +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin + + +class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): + model_class = AutoencoderKLLTX2Audio + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_ltx_video_config(self): + return { + "in_channels": 2, # stereo, + "output_channels": 2, + "latent_channels": 4, + "base_channels": 16, + "ch_mult": (1, 2, 4), + "resolution": 16, + "attn_resolutions": None, + "num_res_blocks": 2, + "norm_type": "pixel", + "causality_axis": "height", + "mid_block_add_attention": False, + "sample_rate": 16000, + "mel_hop_length": 160, + "mel_bins": 16, + "is_causal": True, + "double_z": True, + } + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 2 + num_frames = 8 + num_mel_bins = 16 + + spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device) + + input_dict = {"sample": spectrogram} + return input_dict + + @property + def input_shape(self): + return (2, 5, 16) + + @property + def output_shape(self): + return (2, 5, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_ltx_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + # Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE + def test_output(self): + super().test_output(expected_output_shape=(2, 2, 5, 16)) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skip("AutoencoderKLLTXAudio does not support `norm_num_groups` because it does not use GroupNorm.") + def test_forward_with_norm_groups(self): + pass From a17f5cb63f0a4fb1843616e82c674490c04a4d6a Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Tue, 6 Jan 2026 21:34:57 -0800 Subject: [PATCH 71/86] Apply suggestions from code review Co-authored-by: Sayak Paul --- docs/source/en/api/models/autoencoderkl_audio_ltx_2.md | 1 + src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py | 2 +- .../autoencoders/test_models_autoencoder_kl_ltx2_audio.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md b/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md index 36d1834e58c1..d0024474e9e0 100644 --- a/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md +++ b/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md @@ -24,5 +24,6 @@ vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae" ## AutoencoderKLLTX2Audio [[autodoc]] AutoencoderKLLTX2Audio + - encode - decode - all \ No newline at end of file diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index dc09f44d82c3..41b543de2f6e 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -668,7 +668,7 @@ def forward( class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): r""" - LTX2 audio VAE. Currently, only implements the decoder. + LTX2 audio VAE for encoding and decoding audio latent representations. 
""" _supports_gradient_checkpointing = False diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py b/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py index 3c10330e20aa..ce93dfb42afe 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py @@ -83,6 +83,6 @@ def test_output(self): def test_outputs_equivalence(self): pass - @unittest.skip("AutoencoderKLLTXAudio does not support `norm_num_groups` because it does not use GroupNorm.") + @unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm.") def test_forward_with_norm_groups(self): pass From 964f1068023a0f03c03c62b7ab613e05ee94878f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 7 Jan 2026 06:37:24 +0100 Subject: [PATCH 72/86] Remove print statement in audio VAE --- src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py index 41b543de2f6e..6c9c7dce3d2f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -798,7 +798,6 @@ def forward( z = posterior.sample(generator=generator) else: z = posterior.mode() - print(f"z shape: {z.shape}") dec = self.decode(z) if not return_dict: return (dec.sample,) From 4dfe509916af3cddf4874c783c65d12120f8c37d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 7 Jan 2026 12:12:30 +0530 Subject: [PATCH 73/86] up --- tests/pipelines/ltx2/test_ltx2.py | 2 +- tests/pipelines/ltx2/test_ltx2_image2video.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/ltx2/test_ltx2.py b/tests/pipelines/ltx2/test_ltx2.py index 73d08e6b1a20..6ffc23725022 100644 --- a/tests/pipelines/ltx2/test_ltx2.py +++ b/tests/pipelines/ltx2/test_ltx2.py @@ -222,7 +222,7 @@ def test_inference(self): ) expected_audio_slice = torch.tensor( [ - 0.0229, 0.0503, 0.1220, 0.1083, 0.1745, 0.1075, 0.1779, 0.0974, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + 0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 ] ) # fmt: on diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py index 9c58b4fc413d..1edae9c0e098 100644 --- a/tests/pipelines/ltx2/test_ltx2_image2video.py +++ b/tests/pipelines/ltx2/test_ltx2_image2video.py @@ -224,7 +224,7 @@ def test_inference(self): ) expected_audio_slice = torch.tensor( [ - 0.0229, 0.0503, 0.1220, 0.1083, 0.1745, 0.1075, 0.1779, 0.0974, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + 0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 ] ) # fmt: on From 040c1188d962d994d58d901af0b0347096ce1472 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 7 Jan 2026 12:14:25 +0100 Subject: [PATCH 74/86] Fix bug when calculating audio RoPE coords --- src/diffusers/models/transformers/transformer_ltx2.py | 4 +--- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 2 +- src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py 
index bf3f3d13c59b..bc2559ebbc41 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -761,11 +761,9 @@ def prepare_audio_coords( """ # 1. Generate coordinates in the frame (time) dimension. - audio_duration_s = num_frames / fps - latent_frames = int(audio_duration_s * self.audio_latents_per_second) # Always compute rope in fp32 grid_f = torch.arange( - start=shift, end=latent_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device + start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device ) # 2. Calculate start timstamps in seconds with respect to the original spectrogram grid diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index e95b8d5c0b2e..cbfb5b5c4a1b 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -689,7 +689,7 @@ def prepare_audio_latents( latents_per_second = ( float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) ) - latent_length = int(duration_s * latents_per_second) + latent_length = round(duration_s * latents_per_second) if latents is not None: return latents.to(device=device, dtype=dtype), latent_length diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 5a4b27280958..652955fee129 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -749,7 +749,7 @@ def prepare_audio_latents( latents_per_second = ( float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) ) - latent_length = int(duration_s * latents_per_second) + latent_length = round(duration_s * latents_per_second) if latents is not None: return latents.to(device=device, dtype=dtype), latent_length From 44925cb3f56e63dfc62e8de9103f1127b00635d0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 8 Jan 2026 05:16:19 +0530 Subject: [PATCH 75/86] Ltx 2 latent upsample pipeline (#12922) * Initial implementation of LTX 2.0 latent upsampling pipeline * Add new LTX 2.0 spatial latent upsampler logic * Add test script for LTX 2.0 latent upsampling * Add option to enable VAE tiling in upsampling test script * Get latent upsampler working with video latents * Fix typo in BlurDownsample * Add latent upsample pipeline docstring and example * Remove deprecated pipeline VAE slicing/tiling methods * make style and make quality * When returning latents, return unpacked and denormalized latents for T2V and I2V * Add model_cpu_offload_seq for latent upsampling pipeline --------- Co-authored-by: Daniel Gu --- scripts/convert_ltx2_to_diffusers.py | 48 +- scripts/ltx2_test_latent_upsampler.py | 174 +++++++ src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/ltx2/__init__.py | 4 + .../pipelines/ltx2/latent_upsampler.py | 285 ++++++++++++ src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 32 +- .../ltx2/pipeline_ltx2_image2video.py | 32 +- .../ltx2/pipeline_ltx2_latent_upsample.py | 432 ++++++++++++++++++ 9 files changed, 980 insertions(+), 33 deletions(-) create mode 100644 scripts/ltx2_test_latent_upsampler.py create mode 100644 src/diffusers/pipelines/ltx2/latent_upsampler.py create mode 100644 src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py diff --git a/scripts/convert_ltx2_to_diffusers.py 
b/scripts/convert_ltx2_to_diffusers.py index eb0b010075b4..fa6da10bee78 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -16,7 +16,7 @@ LTX2Pipeline, LTX2VideoTransformer3DModel, ) -from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder +from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder from diffusers.utils.import_utils import is_accelerate_available @@ -577,6 +577,33 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D return vocoder +def get_ltx2_spatial_latent_upsampler_config(version: str): + if version == "2.0": + config = { + "in_channels": 128, + "mid_channels": 1024, + "num_blocks_per_stage": 4, + "dims": 3, + "spatial_upsample": True, + "temporal_upsample": False, + "rational_spatial_scale": 2.0, + } + else: + raise ValueError(f"Unsupported version: {version}") + return config + + +def convert_ltx2_spatial_latent_upsampler( + original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype +): + with init_empty_weights(): + latent_upsampler = LTX2LatentUpsamplerModel(**config) + + latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True) + latent_upsampler.to(dtype) + return latent_upsampler + + def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) @@ -682,6 +709,12 @@ def get_args(): type=str, help="HF Hub id for the LTX 2.0 text tokenizer", ) + parser.add_argument( + "--latent_upsampler_filename", + default="rc1/ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors", + type=str, + help="Latent upsampler filename", + ) parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") @@ -689,6 +722,7 @@ def get_args(): parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model") parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model") parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder") + parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler") parser.add_argument( "--full_pipeline", action="store_true", @@ -788,6 +822,18 @@ def main(args): if not args.full_pipeline: tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer")) + if args.latent_upsampler: + original_latent_upsampler_ckpt = load_hub_or_local_checkpoint( + repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename + ) + latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version) + latent_upsampler = convert_ltx2_spatial_latent_upsampler( + original_latent_upsampler_ckpt, + latent_upsampler_config, + dtype=vae_dtype, + ) + latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler")) + if args.full_pipeline: scheduler = FlowMatchEulerDiscreteScheduler( use_dynamic_shifting=True, diff --git a/scripts/ltx2_test_latent_upsampler.py b/scripts/ltx2_test_latent_upsampler.py new file mode 100644 index 000000000000..6b2e088f23c9 --- /dev/null +++ b/scripts/ltx2_test_latent_upsampler.py @@ -0,0 +1,174 @@ +import argparse +import gc +import os + +import torch + +from diffusers import AutoencoderKLLTX2Video +from 
diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2LatentUpsamplerModel +from diffusers.pipelines.ltx2.export_utils import encode_video +from diffusers.utils import load_image + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model") + parser.add_argument("--revision", type=str, default="main") + + parser.add_argument("--image_path", required=True, type=str) + parser.add_argument( + "--prompt", + type=str, + default=( + "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart " + "in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in " + "slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless " + "motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep " + "darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and " + "scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground " + "dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity " + "motion, cinematic lighting, and a breath-taking, movie-like shot." + ), + ) + parser.add_argument( + "--negative_prompt", + type=str, + default=( + "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion " + "artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." + ), + ) + + parser.add_argument("--num_inference_steps", type=int, default=40) + parser.add_argument("--height", type=int, default=512) + parser.add_argument("--width", type=int, default=768) + parser.add_argument("--num_frames", type=int, default=121) + parser.add_argument("--frame_rate", type=float, default=25.0) + parser.add_argument("--guidance_scale", type=float, default=3.0) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--apply_scheduler_fix", action="store_true") + + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--dtype", type=str, default="bf16") + parser.add_argument("--cpu_offload", action="store_true") + parser.add_argument("--vae_tiling", action="store_true") + parser.add_argument("--use_video_latents", action="store_true") + + parser.add_argument( + "--output_dir", + type=str, + default="samples", + help="Output directory for generated video", + ) + parser.add_argument( + "--output_filename", + type=str, + default="ltx2_i2v_video_upsampled.mp4", + help="Filename of the exported generated video", + ) + + args = parser.parse_args() + args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 + return args + + +def main(args): + pipeline = LTX2ImageToVideoPipeline.from_pretrained( + args.model_id, + revision=args.revision, + torch_dtype=args.dtype, + ) + if args.cpu_offload: + pipeline.enable_model_cpu_offload() + else: + pipeline.to(device=args.device) + + image = load_image(args.image_path) + + first_stage_output_type = "pil" + if args.use_video_latents: + first_stage_output_type = "latent" + + video, audio = pipeline( + image=image, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, 
+ generator=torch.Generator(device=args.device).manual_seed(args.seed), + output_type=first_stage_output_type, + return_dict=False, + ) + + if args.use_video_latents: + # Manually convert the audio latents to a waveform + audio = audio.to(pipeline.audio_vae.dtype) + audio = pipeline.audio_vae.decode(audio, return_dict=False)[0] + audio = pipeline.vocoder(audio) + + # Get some pipeline configs for upsampling + spatial_patch_size = pipeline.transformer_spatial_patch_size + temporal_patch_size = pipeline.transformer_temporal_patch_size + + # upsample_pipeline = LTX2LatentUpsamplePipeline.from_pretrained( + # args.model_id, revision=args.revision, torch_dtype=args.dtype, + # ) + output_sampling_rate = pipeline.vocoder.config.output_sampling_rate + del pipeline # Otherwise there might be an OOM error? + torch.cuda.empty_cache() + gc.collect() + + vae = AutoencoderKLLTX2Video.from_pretrained( + args.model_id, + subfolder="vae", + revision=args.revision, + torch_dtype=args.dtype, + ) + latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + args.model_id, + subfolder="latent_upsampler", + revision=args.revision, + torch_dtype=args.dtype, + ) + upsample_pipeline = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler) + upsample_pipeline.to(device=args.device) + if args.vae_tiling: + upsample_pipeline.vae.enable_tiling() + + upsample_kwargs = { + "height": args.height, + "width": args.width, + "output_type": "np", + "return_dict": False, + } + if args.use_video_latents: + upsample_kwargs["latents"] = video + upsample_kwargs["num_frames"] = args.num_frames + upsample_kwargs["spatial_patch_size"] = spatial_patch_size + upsample_kwargs["temporal_patch_size"] = temporal_patch_size + else: + upsample_kwargs["video"] = video + + video = upsample_pipeline(**upsample_kwargs)[0] + + # Convert video to uint8 (but keep as NumPy array) + video = (video * 255).round().astype("uint8") + video = torch.from_numpy(video) + + encode_video( + video[0], + fps=args.frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=output_sampling_rate, # should be 24000 + output_path=os.path.join(args.output_dir, args.output_filename), + ) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 770afd24e0b5..c749bad4be47 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -542,6 +542,7 @@ "LongCatImageEditPipeline", "LongCatImagePipeline", "LTX2ImageToVideoPipeline", + "LTX2LatentUpsamplePipeline", "LTX2Pipeline", "LTXConditionPipeline", "LTXI2VLongMultiPromptPipeline", @@ -1263,6 +1264,7 @@ LongCatImageEditPipeline, LongCatImagePipeline, LTX2ImageToVideoPipeline, + LTX2LatentUpsamplePipeline, LTX2Pipeline, LTXConditionPipeline, LTXI2VLongMultiPromptPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 464b24f34069..b94319ffcbdc 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -290,7 +290,7 @@ "LTXLatentUpsamplePipeline", "LTXI2VLongMultiPromptPipeline", ] - _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline"] + _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -738,7 +738,7 @@ LTXLatentUpsamplePipeline, LTXPipeline, ) - from 
.ltx2 import LTX2ImageToVideoPipeline, LTX2Pipeline + from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index 2760f8f7feeb..115e83e827a4 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -23,8 +23,10 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["connectors"] = ["LTX2TextConnectors"] + _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"] _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] + _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"] _import_structure["vocoder"] = ["LTX2Vocoder"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -36,8 +38,10 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .connectors import LTX2TextConnectors + from .latent_upsampler import LTX2LatentUpsamplerModel from .pipeline_ltx2 import LTX2Pipeline from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline + from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline from .vocoder import LTX2Vocoder else: diff --git a/src/diffusers/pipelines/ltx2/latent_upsampler.py b/src/diffusers/pipelines/ltx2/latent_upsampler.py new file mode 100644 index 000000000000..69a9b1d9193f --- /dev/null +++ b/src/diffusers/pipelines/ltx2/latent_upsampler.py @@ -0,0 +1,285 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
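+
+# Example usage of the `LTX2LatentUpsamplerModel` defined below (shapes are illustrative and assume a
+# checkpoint converted with scripts/convert_ltx2_to_diffusers.py, which saves the weights under a
+# "latent_upsampler" subfolder). With the default config (128 latent channels, `rational_spatial_scale=2.0`)
+# the model doubles the spatial size of an unnormalized video latent while keeping the batch, channel and
+# frame dimensions intact:
+#
+#     import torch
+#     from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel
+#
+#     upsampler = LTX2LatentUpsamplerModel.from_pretrained("path/to/checkpoint", subfolder="latent_upsampler")
+#     latents = torch.randn(1, 128, 16, 16, 24)  # [B, C, F, H, W]
+#     upsampled = upsampler(latents)  # -> [1, 128, 16, 32, 48]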
+ +import math +from typing import Optional + +import torch +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +RATIONAL_RESAMPLER_SCALE_MAPPING = { + 0.75: (3, 4), + 1.5: (3, 2), + 2.0: (2, 1), + 4.0: (4, 1), +} + + +# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock +class ResBlock(torch.nn.Module): + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.norm2(hidden_states) + hidden_states = self.activation(hidden_states + residual) + return hidden_states + + +# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND +class PixelShuffleND(torch.nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + + self.dims = dims + self.upscale_factors = upscale_factors + + if dims not in [1, 2, 3]: + raise ValueError("dims must be 1, 2, or 3") + + def forward(self, x): + if self.dims == 3: + # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:3])) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(6, 7) + .flatten(4, 5) + .flatten(2, 3) + ) + elif self.dims == 2: + # spatial: b (c p1 p2) h w -> b c (h p1) (w p2) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3) + ) + elif self.dims == 1: + # temporal: b (c p1) f h w -> b c (f p1) h w + return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3) + + +class BlurDownsample(torch.nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W. + Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None: + super().__init__() + + if dims not in (2, 3): + raise ValueError(f"`dims` must be either 2 or 3 but is {dims}") + if kernel_size < 3 or kernel_size % 2 != 1: + raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}") + + self.dims = dims + self.stride = stride + self.kernel_size = kernel_size + + # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from + # the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and + # provides a smooth approximation of a Gaussian filter (often called a "binomial filter"). + # The 2D kernel is constructed as the outer product and normalized. 
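+        # For example, with the default kernel_size=5 the 1D weights are [1, 4, 6, 4, 1], so the normalized
+        # 2D kernel is their outer product divided by 256 (equivalently: torch.outer(k, k) / k.sum() ** 2),
+        # which is then applied depthwise (groups=C) at the requested stride.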
+ k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + if self.dims == 2: + c = x.shape[1] + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + else: + # dims == 3: apply per-frame on H,W + b, c, f, _, _ = x.shape + x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] + + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + + h2, w2 = x.shape[-2:] + x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W] + return x + + +class SpatialRationalResampler(torch.nn.Module): + """ + Scales by the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample + by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the + input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the + (integer) denominator. + """ + + def __init__(self, mid_channels: int = 1024, scale: float = 2.0): + super().__init__() + self.scale = float(scale) + num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None) + if num_denom is None: + raise ValueError( + f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}" + ) + self.num, self.den = num_denom + + self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Expected x shape: [B * F, C, H, W] + # b, _, f, h, w = x.shape + # x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + # x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W] + return x + + +class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. 
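+    For example, with the default configuration (`in_channels=128`, `rational_spatial_scale=2.0`), a video latent of
+    shape `(batch_size, 128, num_frames, height, width)` is upsampled to
+    `(batch_size, 128, num_frames, 2 * height, 2 * width)`.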
+
+    Args:
+        in_channels (`int`, defaults to `128`):
+            Number of channels in the input latent
+        mid_channels (`int`, defaults to `1024`):
+            Number of channels in the middle layers
+        num_blocks_per_stage (`int`, defaults to `4`):
+            Number of ResBlocks to use in each stage (pre/post upsampling)
+        dims (`int`, defaults to `3`):
+            Number of dimensions for convolutions (2 or 3)
+        spatial_upsample (`bool`, defaults to `True`):
+            Whether to spatially upsample the latent
+        temporal_upsample (`bool`, defaults to `False`):
+            Whether to temporally upsample the latent
+        rational_spatial_scale (`float`, *optional*, defaults to `2.0`):
+            Rational spatial scale factor. If `None`, a plain 2x pixel-shuffle upsampler is used instead of the
+            rational resampler.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 128,
+        mid_channels: int = 1024,
+        num_blocks_per_stage: int = 4,
+        dims: int = 3,
+        spatial_upsample: bool = True,
+        temporal_upsample: bool = False,
+        rational_spatial_scale: Optional[float] = 2.0,
+    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.mid_channels = mid_channels
+        self.num_blocks_per_stage = num_blocks_per_stage
+        self.dims = dims
+        self.spatial_upsample = spatial_upsample
+        self.temporal_upsample = temporal_upsample
+
+        ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
+
+        self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
+        self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
+        self.initial_activation = torch.nn.SiLU()
+
+        self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
+
+        if spatial_upsample and temporal_upsample:
+            self.upsampler = torch.nn.Sequential(
+                torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
+                PixelShuffleND(3),
+            )
+        elif spatial_upsample:
+            if rational_spatial_scale is not None:
+                self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale)
+            else:
+                self.upsampler = torch.nn.Sequential(
+                    torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
+                    PixelShuffleND(2),
+                )
+        elif temporal_upsample:
+            self.upsampler = torch.nn.Sequential(
+                torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
+                PixelShuffleND(1),
+            )
+        else:
+            raise ValueError("Either spatial_upsample or temporal_upsample must be True")
+
+        self.post_upsample_res_blocks = torch.nn.ModuleList(
+            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
+        )
+
+        self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+        if self.dims == 2:
+            hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+            hidden_states = self.initial_conv(hidden_states)
+            hidden_states = self.initial_norm(hidden_states)
+            hidden_states = self.initial_activation(hidden_states)
+
+            for block in self.res_blocks:
+                hidden_states = block(hidden_states)
+
+            hidden_states = self.upsampler(hidden_states)
+
+            for block in self.post_upsample_res_blocks:
+                hidden_states = block(hidden_states)
+
+            hidden_states = self.final_conv(hidden_states)
+            hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+        else:
+            hidden_states = self.initial_conv(hidden_states)
+            hidden_states = self.initial_norm(hidden_states)
+            hidden_states = self.initial_activation(hidden_states)
+
+            for block in self.res_blocks:
+                hidden_states = block(hidden_states)
+
+            if self.temporal_upsample:
+                hidden_states = self.upsampler(hidden_states)
+                hidden_states = hidden_states[:, :, 1:, :, :]
+            else:
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + + return hidden_states diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index cbfb5b5c4a1b..3e47a695d5ef 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -1082,21 +1082,27 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + if output_type == "latent": video = latents audio = audio_latents else: - latents = self._unpack_latents( - latents, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - latents = self._denormalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor - ) latents = latents.to(prompt_embeds.dtype) if not self.vae.config.timestep_conditioning: @@ -1121,10 +1127,6 @@ def __call__( video = self.video_processor.postprocess_video(video, output_type=output_type) audio_latents = audio_latents.to(self.audio_vae.dtype) - audio_latents = self._denormalize_audio_latents( - audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std - ) - audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] audio = self.vocoder(generated_mel_spectrograms) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 652955fee129..bff977e6a711 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -1179,21 +1179,27 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + if output_type == "latent": video = latents audio = audio_latents else: - latents = self._unpack_latents( - latents, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - latents = self._denormalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, 
self.vae.config.scaling_factor
-            )
             latents = latents.to(prompt_embeds.dtype)
 
             if not self.vae.config.timestep_conditioning:
@@ -1218,10 +1224,6 @@ def __call__(
             video = self.video_processor.postprocess_video(video, output_type=output_type)
 
             audio_latents = audio_latents.to(self.audio_vae.dtype)
-            audio_latents = self._denormalize_audio_latents(
-                audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
-            )
-            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
             generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
             audio = self.vocoder(generated_mel_spectrograms)
diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
new file mode 100644
index 000000000000..94cddad42fc3
--- /dev/null
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
@@ -0,0 +1,432 @@
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import torch
+
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLLTX2Video
+from ...utils import get_logger, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..ltx.pipeline_output import LTXPipelineOutput
+from ..pipeline_utils import DiffusionPipeline
+from .latent_upsampler import LTX2LatentUpsamplerModel
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
+        >>> from diffusers.utils import load_image
+        >>> from diffusers.pipelines.ltx2.export_utils import encode_video
+
+        >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16)
+        >>> pipe.to("cuda")
+
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
+        ... )
+        >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
+        >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+        >>> video, audio = pipe(
+        ...     image=image,
+        ...     prompt=prompt,
+        ...     negative_prompt=negative_prompt,
+        ...     width=768,
+        ...     height=512,
+        ...     num_frames=121,
+        ...     frame_rate=25.0,
+        ...     num_inference_steps=40,
+        ...     guidance_scale=3.0,
+        ...     output_type="pil",
+        ...     return_dict=False,
+        ... )
+
+        >>> upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained(
+        ...     "Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16
+        ... )
+        >>> upsample_pipe.to("cuda")
+
+        >>> video = upsample_pipe(
+        ...     video=video,
+        ...     width=768,
+        ...     height=512,
+        ...     output_type="pil",
+        ...     return_dict=False,
+        ...
)[0] + + >>> video = (video * 255).round().astype("uint8") + >>> video = torch.from_numpy(video) + >>> encode_video(video[0], fps=25.0, audio=audio[0].float().cpu(), output_path="output.mp4") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LTX2LatentUpsamplePipeline(DiffusionPipeline): + model_cpu_offload_seq = "vae->latent_upsampler" + + def __init__( + self, + vae: AutoencoderKLLTX2Video, + latent_upsampler: LTX2LatentUpsamplerModel, + ) -> None: + super().__init__() + + self.register_modules(vae=vae, latent_upsampler=latent_upsampler) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + def prepare_latents( + self, + video: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_frames: int = 121, + height: int = 512, + width: int = 768, + spatial_patch_size: int = 1, + temporal_patch_size: int = 1, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 3: + # Convert token seq [B, S, D] to latent video [B, C, F, H, W] + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + latents = self._unpack_latents( + latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size + ) + return latents.to(device=device, dtype=dtype) + + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + # NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here + # init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + return init_latents + + def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent + tensor. 
+
+        Args:
+            latents (`torch.Tensor`):
+                Input latents to normalize.
+            reference_latents (`torch.Tensor`):
+                The reference latents providing style statistics.
+            factor (`float`):
+                Blending factor between the original and transformed latents, in the range [-10.0, 10.0]. Defaults to
+                1.0.
+
+        Returns:
+            `torch.Tensor`: The transformed latent tensor.
+        """
+        result = latents.clone()
+
+        for i in range(latents.size(0)):
+            for c in range(latents.size(1)):
+                # Per-sample, per-channel statistics computed over all remaining dimensions
+                r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None)
+                i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
+
+                result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
+
+        result = torch.lerp(latents, result, factor)
+        return result
+
+    def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor:
+        """
+        Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
+        smooth way using a sigmoid-based compression.
+
+        This is useful for regularizing high-variance latents or for conditioning outputs during generation,
+        especially when controlling dynamic behavior with a `compression` factor.
+
+        Args:
+            latents (`torch.Tensor`):
+                Input latent tensor with arbitrary shape. Expected to be roughly in the [-1, 1] or [0, 1] range.
+            compression (`float`):
+                Compression strength in the range [0, 1], where 0.0 applies no tone-mapping (identity transform) and
+                1.0 applies the full compression effect.
+
+        Returns:
+            `torch.Tensor`: The tone-mapped latent tensor with the same shape as the input.
+        """
+        # Remap compression in [0, 1] to a scale factor in [0, 0.75]
+        scale_factor = compression * 0.75
+        abs_latents = torch.abs(latents)
+
+        # Sigmoid compression: large-magnitude latents are scaled down (to roughly 0.4x at full compression) while
+        # small values stay close to 1.0. When scale_factor=0 the transform is the identity.
+        sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
+        scales = 1.0 - 0.8 * scale_factor * sigmoid_term
+
+        filtered = latents * scales
+        return filtered
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
+    def _normalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        # Normalize latents across the channel dimension [B, C, F, H, W]
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = (latents - latents_mean) * scaling_factor / latents_std
+        return latents
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
+    def _denormalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        # Denormalize latents across the channel dimension [B, C, F, H, W]
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = latents * latents_std / scaling_factor + latents_mean
+        return latents
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
+    def _unpack_latents(
+        latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
+    ) -> torch.Tensor:
+        # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature
+        # dimension) are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse
+        # operation of what happens in the `_pack_latents` method.
+        batch_size = latents.size(0)
+        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+        return latents
+
+    def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
+        if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_spatial_compression_ratio} but are"
+                f" {height} and {width}."
+            )
+
+        if video is not None and latents is not None:
+            raise ValueError("Only one of `video` or `latents` can be provided.")
+        if video is None and latents is None:
+            raise ValueError("One of `video` or `latents` has to be provided.")
+
+        if not (0 <= tone_map_compression_ratio <= 1):
+            raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        video: Optional[List[PipelineImageInput]] = None,
+        height: int = 512,
+        width: int = 768,
+        num_frames: int = 121,
+        spatial_patch_size: int = 1,
+        temporal_patch_size: int = 1,
+        latents: Optional[torch.Tensor] = None,
+        latents_normalized: bool = False,
+        decode_timestep: Union[float, List[float]] = 0.0,
+        decode_noise_scale: Optional[Union[float, List[float]]] = None,
+        adain_factor: float = 0.0,
+        tone_map_compression_ratio: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            video (`List[PipelineImageInput]`, *optional*):
+                The video to be upsampled (such as an LTX 2.0 first-stage output). If not supplied, `latents` should
+                be supplied.
+            height (`int`, *optional*, defaults to `512`):
+                The height in pixels of the input video (not the generated video, which will have a larger resolution).
+            width (`int`, *optional*, defaults to `768`):
+                The width in pixels of the input video (not the generated video, which will have a larger resolution).
+            num_frames (`int`, *optional*, defaults to `121`):
+                The number of frames in the input video.
+            spatial_patch_size (`int`, *optional*, defaults to `1`):
+                The spatial patch size of the video latents. Used to unpack `latents` when they are supplied as a
+                packed patch sequence.
+            temporal_patch_size (`int`, *optional*, defaults to `1`):
+                The temporal patch size of the video latents. Used to unpack `latents` when they are supplied as a
+                packed patch sequence.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
+                patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
+                latent_channels, latent_frames, latent_height, latent_width)`.
+            latents_normalized (`bool`, *optional*, defaults to `False`):
+                If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If
+                `True`, the `latents` will be denormalized before being supplied to the latent upsampler.
+            decode_timestep (`float`, defaults to `0.0`):
+                The timestep at which generated video is decoded.
+            decode_noise_scale (`float`, defaults to `None`):
+                The interpolation factor between random noise and denoised latents at the decode timestep.
+            adain_factor (`float`, *optional*, defaults to `0.0`):
+                Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents.
+                Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed.
+            tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`):
+                The compression strength for tone mapping, which will reduce the dynamic range of the latent values.
+                This is useful for regularizing high-variance latents or for conditioning outputs during generation.
+                Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to
+                the full compression effect.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is the upsampled video.
+        """
+
+        self.check_inputs(
+            video=video,
+            height=height,
+            width=width,
+            latents=latents,
+            tone_map_compression_ratio=tone_map_compression_ratio,
+        )
+
+        if video is not None:
+            # Batched video input is not yet tested/supported. TODO: take a look later
+            batch_size = 1
+        else:
+            batch_size = latents.shape[0]
+        device = self._execution_device
+
+        if video is not None:
+            num_frames = len(video)
+            if num_frames % self.vae_temporal_compression_ratio != 1:
+                original_num_frames = num_frames
+                num_frames = (
+                    num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
+                )
+                video = video[:num_frames]
+                logger.warning(
+                    f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {original_num_frames}. Truncating to {num_frames} frames."
+ ) + video = self.video_processor.preprocess_video(video, height=height, width=width) + video = video.to(device=device, dtype=torch.float32) + + latents_supplied = latents is not None + latents = self.prepare_latents( + video=video, + batch_size=batch_size, + num_frames=num_frames, + height=height, + width=width, + spatial_patch_size=spatial_patch_size, + temporal_patch_size=temporal_patch_size, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + + if latents_supplied and latents_normalized: + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.latent_upsampler.dtype) + latents_upsampled = self.latent_upsampler(latents) + + if adain_factor > 0.0: + latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor) + else: + latents = latents_upsampled + + if tone_map_compression_ratio > 0.0: + latents = self.tone_map_latents(latents, tone_map_compression_ratio) + + if output_type == "latent": + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + video = latents + else: + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) From 5e50046728f0ee24fd1e8f3d0e6c1bc2cab36861 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 8 Jan 2026 01:01:30 +0100 Subject: [PATCH 76/86] Fix latent upsampler filename in LTX 2 conversion script --- scripts/convert_ltx2_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index fa6da10bee78..d84722f278c2 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -711,7 +711,7 @@ def get_args(): ) parser.add_argument( "--latent_upsampler_filename", - default="rc1/ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors", + default="ltx-2-spatial-upscaler-x2-1.0.safetensors", type=str, help="Latent upsampler filename", ) From 2b85b93e68856139b339b201f1a44dfe8094eebe Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 8 Jan 2026 01:18:37 +0100 Subject: [PATCH 77/86] Add latent upsample pipeline to LTX 2 docs --- docs/source/en/api/pipelines/ltx2.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md index c71def1ab716..231e3112a907 100644 --- a/docs/source/en/api/pipelines/ltx2.md +++ b/docs/source/en/api/pipelines/ltx2.md @@ -32,6 +32,12 @@ The original codebase for LTX-2 can be found 
[here](https://github.com/Lightrick - all - __call__ +## LTX2LatentUpsamplePipeline + +[[autodoc]] LTX2LatentUpsamplePipeline + - all + - __call__ + ## LTX2PipelineOutput [[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput From 40ee3e33dd7cd25a03c34925bc7c37a7ea4fc996 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 8 Jan 2026 01:21:39 +0100 Subject: [PATCH 78/86] Add dummy objects for LTX 2 latent upsample pipeline --- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 5781f60249d9..a7f0c5b85dd8 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1892,6 +1892,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTX2LatentUpsamplePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTX2Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 99ff722e08040b14a9a7a1d5a3633dd1ea115590 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 8 Jan 2026 01:48:31 +0100 Subject: [PATCH 79/86] Set default FPS to official LTX 2 ckpt default of 24.0 --- src/diffusers/models/transformers/transformer_ltx2.py | 7 +++---- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 6 +++--- src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 6 +++--- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index bc2559ebbc41..bf2a02c80bda 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -667,7 +667,7 @@ def prepare_video_coords( height: int, width: int, device: torch.device, - fps: float = 25.0, + fps: float = 24.0, ) -> torch.Tensor: """ Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel @@ -735,7 +735,6 @@ def prepare_audio_coords( batch_size: int, num_frames: int, device: torch.device, - fps: float = 25.0, shift: int = 0, ) -> torch.Tensor: """ @@ -1115,7 +1114,7 @@ def forward( num_frames: Optional[int] = None, height: Optional[int] = None, width: Optional[int] = None, - fps: float = 25.0, + fps: float = 24.0, audio_num_frames: Optional[int] = None, video_coords: Optional[torch.Tensor] = None, audio_coords: Optional[torch.Tensor] = None, @@ -1176,7 +1175,7 @@ def forward( ) if audio_coords is None: audio_coords = self.audio_rope.prepare_audio_coords( - batch_size, audio_num_frames, audio_hidden_states.device, fps=fps + batch_size, audio_num_frames, audio_hidden_states.device ) video_rotary_emb = self.rope(video_coords, device=hidden_states.device) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 3e47a695d5ef..00184300a1e6 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -746,7 +746,7 @@ def __call__( height: int = 512, width: 
int = 768, num_frames: int = 121, - frame_rate: float = 25.0, + frame_rate: float = 24.0, num_inference_steps: int = 40, timesteps: List[int] = None, guidance_scale: float = 3.0, @@ -781,7 +781,7 @@ def __call__( The width in pixels of the generated image. This is set to 848 by default for the best results. num_frames (`int`, *optional*, defaults to `121`): The number of video frames to generate - frame_rate (`float`, *optional*, defaults to `25.0`): + frame_rate (`float`, *optional*, defaults to `24.0`): The frames per second (FPS) of the generated video. num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -996,7 +996,7 @@ def __call__( latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate ) audio_coords = self.transformer.audio_rope.prepare_audio_coords( - audio_latents.shape[0], audio_num_frames, audio_latents.device, fps=frame_rate + audio_latents.shape[0], audio_num_frames, audio_latents.device ) # 7. Denoising loop diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index bff977e6a711..1dd28a127661 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -807,7 +807,7 @@ def __call__( height: int = 512, width: int = 768, num_frames: int = 121, - frame_rate: float = 25.0, + frame_rate: float = 24.0, num_inference_steps: int = 40, timesteps: List[int] = None, guidance_scale: float = 3.0, @@ -844,7 +844,7 @@ def __call__( The width in pixels of the generated image. This is set to 848 by default for the best results. num_frames (`int`, *optional*, defaults to `121`): The number of video frames to generate - frame_rate (`float`, *optional*, defaults to `25.0`): + frame_rate (`float`, *optional*, defaults to `24.0`): The frames per second (FPS) of the generated video. num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -1067,7 +1067,7 @@ def __call__( latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate ) audio_coords = self.transformer.audio_rope.prepare_audio_coords( - audio_latents.shape[0], audio_num_frames, audio_latents.device, fps=frame_rate + audio_latents.shape[0], audio_num_frames, audio_latents.device ) # 7. 
Denoising loop From 165b945450e5483c2a5dc2c4f42a6100a9037c97 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 8 Jan 2026 02:31:02 +0100 Subject: [PATCH 80/86] Set default CFG scale to official LTX 2 ckpt default of 4.0 --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 4 ++-- src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 00184300a1e6..6a964e34b9e2 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -749,7 +749,7 @@ def __call__( frame_rate: float = 24.0, num_inference_steps: int = 40, timesteps: List[int] = None, - guidance_scale: float = 3.0, + guidance_scale: float = 4.0, guidance_rescale: float = 0.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -790,7 +790,7 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to `3.0`): + guidance_scale (`float`, *optional*, defaults to `4.0`): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 1dd28a127661..1e859824a2da 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -810,7 +810,7 @@ def __call__( frame_rate: float = 24.0, num_inference_steps: int = 40, timesteps: List[int] = None, - guidance_scale: float = 3.0, + guidance_scale: float = 4.0, guidance_rescale: float = 0.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -853,7 +853,7 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to `3.0`): + guidance_scale (`float`, *optional*, defaults to `4.0`): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). 
Guidance scale is enabled by setting From 1a4ae58cda76a51613ec7a15fe053af6a84420fd Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 8 Jan 2026 03:16:41 +0100 Subject: [PATCH 81/86] Update LTX 2 pipeline example docstrings --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 4 ++- .../ltx2/pipeline_ltx2_image2video.py | 2 ++ .../ltx2/pipeline_ltx2_latent_upsample.py | 28 ++++++++++++------- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 6a964e34b9e2..99d6b71ec3d7 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -62,8 +62,10 @@ ... negative_prompt=negative_prompt, ... width=768, ... height=512, - ... frame_rate=frame_rate, ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, ... output_type="np", ... return_dict=False, ... ) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 1e859824a2da..b1711e283191 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -70,6 +70,8 @@ ... height=512, ... num_frames=121, ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, ... output_type="np", ... return_dict=False, ... ) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py index 94cddad42fc3..4d9c2c81f0dc 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -33,12 +33,12 @@ Examples: ```py >>> import torch - >>> from diffusers import LTX2ImageToVideoPipeline, LTX2 - >>> from diffusers.utils import load_image + >>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.utils import load_image - >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") + >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() >>> image = load_image( ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" @@ -46,6 +46,7 @@ >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background." >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + >>> frame_rate = 24.0 >>> video, audio = pipe( ... image=image, ... prompt=prompt, @@ -53,29 +54,36 @@ ... width=768, ... height=512, ... num_frames=121, - ... frame_rate=25.0, + ... frame_rate=frame_rate, ... num_inference_steps=40, - ... guidance_scale=3.0, + ... guidance_scale=4.0, ... output_type="pil", ... return_dict=False, ... ) >>> upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained( - ... "Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16 + ... "Lightricks/LTX-2", torch_dtype=torch.bfloat16 ... ) + >>> upsample_pipe.vae.enable_tiling() >>> upsample_pipe.to("cuda") >>> video = upsample_pipe( ... video=video, ... width=768, ... height=512, - ... output_type="pil", + ... output_type="np", ... return_dict=False, ... 
)[0] - >>> video = (video * 255).round().astype("uint8") >>> video = torch.from_numpy(video) - >>> encode_video(video[0], fps=25.0, audio=audio[0].float().cpu(), output_path="output.mp4") + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) ``` """ From b4d33df989614f751efee673af6e236a5fceb435 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 8 Jan 2026 03:24:44 +0100 Subject: [PATCH 82/86] make style and make quality --- src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py index 4d9c2c81f0dc..901623d0f5a9 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -61,9 +61,7 @@ ... return_dict=False, ... ) - >>> upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained( - ... "Lightricks/LTX-2", torch_dtype=torch.bfloat16 - ... ) + >>> upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) >>> upsample_pipe.vae.enable_tiling() >>> upsample_pipe.to("cuda") From 724afee974492e7219ea8be5d19988fb9db30b7c Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 8 Jan 2026 03:26:29 +0100 Subject: [PATCH 83/86] Remove LTX 2 test scripts --- scripts/ltx2_test_full_pipeline.py | 108 --------------- scripts/ltx2_test_full_pipeline_i2v.py | 102 --------------- scripts/ltx2_test_latent_upsampler.py | 174 ------------------------- scripts/test_ltx2_audio_conversion.py | 119 ----------------- 4 files changed, 503 deletions(-) delete mode 100644 scripts/ltx2_test_full_pipeline.py delete mode 100644 scripts/ltx2_test_full_pipeline_i2v.py delete mode 100644 scripts/ltx2_test_latent_upsampler.py delete mode 100644 scripts/test_ltx2_audio_conversion.py diff --git a/scripts/ltx2_test_full_pipeline.py b/scripts/ltx2_test_full_pipeline.py deleted file mode 100644 index 20eb65fde25c..000000000000 --- a/scripts/ltx2_test_full_pipeline.py +++ /dev/null @@ -1,108 +0,0 @@ -import argparse -import os - -import torch - -from diffusers import LTX2Pipeline -from diffusers.pipelines.ltx2.export_utils import encode_video - - -def parse_args(): - parser = argparse.ArgumentParser() - - parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model") - parser.add_argument("--revision", type=str, default="main") - - parser.add_argument( - "--prompt", - type=str, - default="A video of a dog dancing to energetic electronic dance music", - ) - parser.add_argument( - "--negative_prompt", - type=str, - default=( - "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " - "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " - "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " - "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " - "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " - "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " - "valley effect, incorrect ethnicity, wrong gender, exaggerated 
expressions, wrong gaze direction, " - "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " - "off-sync audio,incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " - "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " - "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." - ), - ) - - parser.add_argument("--num_inference_steps", type=int, default=40) - parser.add_argument("--height", type=int, default=512) - parser.add_argument("--width", type=int, default=768) - parser.add_argument("--num_frames", type=int, default=121) - parser.add_argument("--frame_rate", type=float, default=25.0) - parser.add_argument("--guidance_scale", type=float, default=3.0) - parser.add_argument("--seed", type=int, default=42) - - parser.add_argument("--device", type=str, default="cuda:0") - parser.add_argument("--dtype", type=str, default="bf16") - parser.add_argument("--cpu_offload", action="store_true") - - parser.add_argument( - "--output_dir", - type=str, - default="/home/daniel_gu/samples", - help="Output directory for generated video", - ) - parser.add_argument( - "--output_filename", - type=str, - default="ltx2_sample_video.mp4", - help="Filename of the exported generated video", - ) - - args = parser.parse_args() - args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 - return args - - -def main(args): - pipeline = LTX2Pipeline.from_pretrained( - args.model_id, - revision=args.revision, - torch_dtype=args.dtype, - ) - pipeline.to(device=args.device) - if args.cpu_offload: - pipeline.enable_model_cpu_offload() - - video, audio = pipeline( - prompt=args.prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - frame_rate=args.frame_rate, - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - generator=torch.Generator(device=args.device).manual_seed(args.seed), - output_type="np", - return_dict=False, - ) - - # Convert video to uint8 (but keep as NumPy array) - video = (video * 255).round().astype("uint8") - video = torch.from_numpy(video) - - encode_video( - video[0], - fps=args.frame_rate, - audio=audio[0].float().cpu(), - audio_sample_rate=pipeline.vocoder.config.output_sampling_rate, # should be 24000 - output_path=os.path.join(args.output_dir, args.output_filename), - ) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/scripts/ltx2_test_full_pipeline_i2v.py b/scripts/ltx2_test_full_pipeline_i2v.py deleted file mode 100644 index cbe61eecdf7c..000000000000 --- a/scripts/ltx2_test_full_pipeline_i2v.py +++ /dev/null @@ -1,102 +0,0 @@ -import argparse -import os - -import torch - -from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline -from diffusers.pipelines.ltx2.export_utils import encode_video -from diffusers.utils import load_image - - -def parse_args(): - parser = argparse.ArgumentParser() - - parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model") - parser.add_argument("--revision", type=str, default="main") - - parser.add_argument("--image_path", required=True, type=str) - parser.add_argument( - "--prompt", - type=str, - default="An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. 
Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot.", - ) - parser.add_argument( - "--negative_prompt", - type=str, - default="shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static.", - ) - - parser.add_argument("--num_inference_steps", type=int, default=40) - parser.add_argument("--height", type=int, default=512) - parser.add_argument("--width", type=int, default=768) - parser.add_argument("--num_frames", type=int, default=121) - parser.add_argument("--frame_rate", type=float, default=25.0) - parser.add_argument("--guidance_scale", type=float, default=3.0) - parser.add_argument("--seed", type=int, default=42) - - parser.add_argument("--device", type=str, default="cuda:0") - parser.add_argument("--dtype", type=str, default="bf16") - parser.add_argument("--cpu_offload", action="store_true") - - parser.add_argument( - "--output_dir", - type=str, - default="samples", - help="Output directory for generated video", - ) - parser.add_argument( - "--output_filename", - type=str, - default="ltx2_sample_video.mp4", - help="Filename of the exported generated video", - ) - - args = parser.parse_args() - args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 - return args - - -def main(args): - pipeline = LTX2ImageToVideoPipeline.from_pretrained( - args.model_id, - revision=args.revision, - torch_dtype=args.dtype, - ) - if args.cpu_offload: - pipeline.enable_model_cpu_offload() - else: - pipeline.to(device=args.device) - - image = load_image(args.image_path) - - video, audio = pipeline( - image=image, - prompt=args.prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - frame_rate=args.frame_rate, - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - generator=torch.Generator(device=args.device).manual_seed(args.seed), - output_type="np", - return_dict=False, - ) - - # Convert video to uint8 (but keep as NumPy array) - video = (video * 255).round().astype("uint8") - video = torch.from_numpy(video) - - encode_video( - video[0], - fps=args.frame_rate, - audio=audio[0].float().cpu(), - audio_sample_rate=pipeline.vocoder.config.output_sampling_rate, # should be 24000 - output_path=os.path.join(args.output_dir, args.output_filename), - ) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/scripts/ltx2_test_latent_upsampler.py b/scripts/ltx2_test_latent_upsampler.py deleted file mode 100644 index 6b2e088f23c9..000000000000 --- a/scripts/ltx2_test_latent_upsampler.py +++ /dev/null @@ -1,174 +0,0 @@ -import argparse -import gc -import os - -import torch - -from diffusers import AutoencoderKLLTX2Video -from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2LatentUpsamplerModel -from diffusers.pipelines.ltx2.export_utils 
import encode_video -from diffusers.utils import load_image - - -def parse_args(): - parser = argparse.ArgumentParser() - - parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model") - parser.add_argument("--revision", type=str, default="main") - - parser.add_argument("--image_path", required=True, type=str) - parser.add_argument( - "--prompt", - type=str, - default=( - "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart " - "in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in " - "slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless " - "motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep " - "darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and " - "scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground " - "dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity " - "motion, cinematic lighting, and a breath-taking, movie-like shot." - ), - ) - parser.add_argument( - "--negative_prompt", - type=str, - default=( - "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion " - "artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." - ), - ) - - parser.add_argument("--num_inference_steps", type=int, default=40) - parser.add_argument("--height", type=int, default=512) - parser.add_argument("--width", type=int, default=768) - parser.add_argument("--num_frames", type=int, default=121) - parser.add_argument("--frame_rate", type=float, default=25.0) - parser.add_argument("--guidance_scale", type=float, default=3.0) - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--apply_scheduler_fix", action="store_true") - - parser.add_argument("--device", type=str, default="cuda:0") - parser.add_argument("--dtype", type=str, default="bf16") - parser.add_argument("--cpu_offload", action="store_true") - parser.add_argument("--vae_tiling", action="store_true") - parser.add_argument("--use_video_latents", action="store_true") - - parser.add_argument( - "--output_dir", - type=str, - default="samples", - help="Output directory for generated video", - ) - parser.add_argument( - "--output_filename", - type=str, - default="ltx2_i2v_video_upsampled.mp4", - help="Filename of the exported generated video", - ) - - args = parser.parse_args() - args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 - return args - - -def main(args): - pipeline = LTX2ImageToVideoPipeline.from_pretrained( - args.model_id, - revision=args.revision, - torch_dtype=args.dtype, - ) - if args.cpu_offload: - pipeline.enable_model_cpu_offload() - else: - pipeline.to(device=args.device) - - image = load_image(args.image_path) - - first_stage_output_type = "pil" - if args.use_video_latents: - first_stage_output_type = "latent" - - video, audio = pipeline( - image=image, - prompt=args.prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - frame_rate=args.frame_rate, - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - generator=torch.Generator(device=args.device).manual_seed(args.seed), - output_type=first_stage_output_type, - return_dict=False, - ) - - if 
args.use_video_latents: - # Manually convert the audio latents to a waveform - audio = audio.to(pipeline.audio_vae.dtype) - audio = pipeline.audio_vae.decode(audio, return_dict=False)[0] - audio = pipeline.vocoder(audio) - - # Get some pipeline configs for upsampling - spatial_patch_size = pipeline.transformer_spatial_patch_size - temporal_patch_size = pipeline.transformer_temporal_patch_size - - # upsample_pipeline = LTX2LatentUpsamplePipeline.from_pretrained( - # args.model_id, revision=args.revision, torch_dtype=args.dtype, - # ) - output_sampling_rate = pipeline.vocoder.config.output_sampling_rate - del pipeline # Otherwise there might be an OOM error? - torch.cuda.empty_cache() - gc.collect() - - vae = AutoencoderKLLTX2Video.from_pretrained( - args.model_id, - subfolder="vae", - revision=args.revision, - torch_dtype=args.dtype, - ) - latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( - args.model_id, - subfolder="latent_upsampler", - revision=args.revision, - torch_dtype=args.dtype, - ) - upsample_pipeline = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler) - upsample_pipeline.to(device=args.device) - if args.vae_tiling: - upsample_pipeline.vae.enable_tiling() - - upsample_kwargs = { - "height": args.height, - "width": args.width, - "output_type": "np", - "return_dict": False, - } - if args.use_video_latents: - upsample_kwargs["latents"] = video - upsample_kwargs["num_frames"] = args.num_frames - upsample_kwargs["spatial_patch_size"] = spatial_patch_size - upsample_kwargs["temporal_patch_size"] = temporal_patch_size - else: - upsample_kwargs["video"] = video - - video = upsample_pipeline(**upsample_kwargs)[0] - - # Convert video to uint8 (but keep as NumPy array) - video = (video * 255).round().astype("uint8") - video = torch.from_numpy(video) - - encode_video( - video[0], - fps=args.frame_rate, - audio=audio[0].float().cpu(), - audio_sample_rate=output_sampling_rate, # should be 24000 - output_path=os.path.join(args.output_dir, args.output_filename), - ) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/scripts/test_ltx2_audio_conversion.py b/scripts/test_ltx2_audio_conversion.py deleted file mode 100644 index 3aa2a65d3f16..000000000000 --- a/scripts/test_ltx2_audio_conversion.py +++ /dev/null @@ -1,119 +0,0 @@ -import argparse -from pathlib import Path - -import torch -from huggingface_hub import hf_hub_download - - -def download_checkpoint( - repo_id="diffusers-internal-dev/new-ltx-model", - filename="ltx-av-step-1932500-interleaved-new-vae.safetensors", -): - ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) - return ckpt_path - - -def convert_state_dict(state_dict: dict) -> dict: - converted = {} - for key, value in state_dict.items(): - if not isinstance(value, torch.Tensor): - continue - new_key = key - if new_key.startswith("decoder."): - new_key = new_key[len("decoder.") :] - converted[f"decoder.{new_key}"] = value - - converted["latents_mean"] = converted.pop("decoder.per_channel_statistics.mean-of-means") - converted["latents_std"] = converted.pop("decoder.per_channel_statistics.std-of-means") - return converted - - -def load_original_decoder(device: torch.device, dtype: torch.dtype): - from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder - from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER - from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator - - checkpoint_path = 
download_checkpoint() - - # The code below comes from `ltx-pipelines/src/ltx_pipelines/txt2vid.py` - decoder = Builder( - model_path=checkpoint_path, - model_class_configurator=AudioDecoderConfigurator, - model_sd_key_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, - ).build(device=device) - - decoder.eval() - return decoder - - -def build_diffusers_decoder(): - from diffusers.models.autoencoders import AutoencoderKLLTX2Audio - - with torch.device("meta"): - model = AutoencoderKLLTX2Audio() - - model.eval() - return model - - -@torch.no_grad() -def main() -> None: - parser = argparse.ArgumentParser(description="Validate LTX2 audio decoder conversion.") - parser.add_argument("--device", type=str, default="cpu") - parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"]) - parser.add_argument("--batch", type=int, default=2) - parser.add_argument("--output-path", type=Path, required=True) - args = parser.parse_args() - - device = torch.device(args.device) - dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} - dtype = dtype_map[args.dtype] - - original_decoder = load_original_decoder(device, dtype) - diffusers_model = build_diffusers_decoder() - - converted_state_dict = convert_state_dict(original_decoder.state_dict()) - diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=False) - - per_channel_len = original_decoder.per_channel_statistics.get_buffer("std-of-means").numel() - latent_channels = diffusers_model.decoder.latent_channels - mel_bins_for_match = per_channel_len // latent_channels if per_channel_len % latent_channels == 0 else None - - levels = len(diffusers_model.decoder.channel_multipliers) - latent_height = diffusers_model.decoder.resolution // (2 ** (levels - 1)) - latent_width = mel_bins_for_match or latent_height - - dummy = torch.randn( - args.batch, - diffusers_model.decoder.latent_channels, - latent_height, - latent_width, - device=device, - dtype=dtype, - generator=torch.Generator(device).manual_seed(42), - ) - - original_out = original_decoder(dummy) - - from diffusers.pipelines.ltx2.pipeline_ltx2 import LTX2Pipeline - - _, a_channels, a_time, a_freq = dummy.shape - dummy = dummy.permute(0, 2, 1, 3).reshape(-1, a_time, a_channels * a_freq) - dummy = LTX2Pipeline._denormalize_audio_latents( - dummy, - diffusers_model.latents_mean, - diffusers_model.latents_std, - ) - dummy = dummy.view(-1, a_time, a_channels, a_freq).permute(0, 2, 1, 3) - diffusers_out = diffusers_model.decode(dummy).sample - - torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4) - max_diff = (diffusers_out - original_out).abs().max().item() - print(f"Conversion successful. 
Max diff: {max_diff:.6f}") - - diffusers_model.to(dtype).save_pretrained(args.output_path) - print(f"Serialized model to {args.output_path}") - - -if __name__ == "__main__": - main() From d24faa7163070845acffcbbd82df145416dec5cb Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 8 Jan 2026 04:51:33 +0100 Subject: [PATCH 84/86] Fix LTX 2 upsample pipeline example docstring --- .../pipelines/ltx2/pipeline_ltx2_latent_upsample.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py index 901623d0f5a9..a44c40b0430f 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -35,6 +35,7 @@ >>> import torch >>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel >>> from diffusers.utils import load_image >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) @@ -61,9 +62,12 @@ ... return_dict=False, ... ) - >>> upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + ... "Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=torch.bfloat16 + ... ) + >>> upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) >>> upsample_pipe.vae.enable_tiling() - >>> upsample_pipe.to("cuda") + >>> upsample_pipe.to(device="cuda", dtype=torch.bfloat16) >>> video = upsample_pipe( ... video=video, From 353f0dbdd06086cda8bb947038fedf0f2fe86e3f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 8 Jan 2026 04:52:18 +0100 Subject: [PATCH 85/86] Add logic to convert and save a LTX 2 upsampling pipeline --- scripts/convert_ltx2_to_diffusers.py | 33 +++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index d84722f278c2..5367113365a2 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -13,6 +13,7 @@ AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler, + LTX2LatentUpsamplePipeline, LTX2Pipeline, LTX2VideoTransformer3DModel, ) @@ -728,6 +729,11 @@ def get_args(): action="store_true", help="Whether to save the pipeline. This will attempt to convert all models (e.g. 
vae, dit, etc.)", ) + parser.add_argument( + "--upsample_pipeline", + action="store_true", + help="Whether to save a latent upsampling pipeline", + ) parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) @@ -762,18 +768,26 @@ def main(args): combined_ckpt = None load_combined_models = any( - [args.vae, args.audio_vae, args.dit, args.vocoder, args.text_encoder, args.full_pipeline] + [ + args.vae, + args.audio_vae, + args.dit, + args.vocoder, + args.text_encoder, + args.full_pipeline, + args.upsample_pipeline, + ] ) if args.combined_filename is not None and load_combined_models: combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename) - if args.vae or args.full_pipeline: + if args.vae or args.full_pipeline or args.upsample_pipeline: if args.vae_filename is not None: original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) elif combined_ckpt is not None: original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version) - if not args.full_pipeline: + if not args.full_pipeline and not args.upsample_pipeline: vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) if args.audio_vae or args.full_pipeline: @@ -822,7 +836,7 @@ def main(args): if not args.full_pipeline: tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer")) - if args.latent_upsampler: + if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline: original_latent_upsampler_ckpt = load_hub_or_local_checkpoint( repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename ) @@ -832,7 +846,8 @@ def main(args): latent_upsampler_config, dtype=vae_dtype, ) - latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler")) + if not args.full_pipeline and not args.upsample_pipeline: + latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler")) if args.full_pipeline: scheduler = FlowMatchEulerDiscreteScheduler( @@ -857,6 +872,14 @@ def main(args): pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + if args.upsample_pipeline: + pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler) + + # Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline + pipe.save_pretrained( + os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB" + ) + if __name__ == "__main__": args = get_args() From f85b969a3d6e515765dcfd9816fed693e5883c4d Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 8 Jan 2026 06:01:34 +0100 Subject: [PATCH 86/86] Document LTX2VideoTransformer3DModel forward pass --- .../models/transformers/transformer_ltx2.py | 40 +++++++++++++++++-- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index bf2a02c80bda..b88f096e8033 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -1126,12 +1126,44 @@ def forward( Args: hidden_states (`torch.Tensor`): - Input patchified video latents of shape (batch_size, num_video_tokens, in_channels). + Input patchified video latents of shape `(batch_size, num_video_tokens, in_channels)`. 
             audio_hidden_states (`torch.Tensor`):
-                Input patchified audio latents of shape (batch_size, num_audio_tokens, audio_in_channels).
+                Input patchified audio latents of shape `(batch_size, num_audio_tokens, audio_in_channels)`.
             encoder_hidden_states (`torch.Tensor`):
-                Input text embeddings of shape TODO.
-                TODO for the rest.
+                Input video text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`.
+            audio_encoder_hidden_states (`torch.Tensor`):
+                Input audio text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`.
+            timestep (`torch.Tensor`):
+                Input timestep of shape `(batch_size, num_video_tokens)`. These should already be scaled by
+                `self.config.timestep_scale_multiplier`.
+            audio_timestep (`torch.Tensor`, *optional*):
+                Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for the audio modulation
+                parameters. This is only used by certain pipelines such as the I2V pipeline.
+            encoder_attention_mask (`torch.Tensor`, *optional*):
+                Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
+            audio_encoder_attention_mask (`torch.Tensor`, *optional*):
+                Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for the audio text
+                embeddings.
+            num_frames (`int`, *optional*):
+                The number of latent video frames. Used when computing the video coordinates for RoPE.
+            height (`int`, *optional*):
+                The latent video height. Used when computing the video coordinates for RoPE.
+            width (`int`, *optional*):
+                The latent video width. Used when computing the video coordinates for RoPE.
+            fps (`float`, *optional*, defaults to `24.0`):
+                The desired frames per second of the generated video. Used when computing the video coordinates for
+                RoPE.
+            audio_num_frames (`int`, *optional*):
+                The number of latent audio frames. Used when computing the audio coordinates for RoPE.
+            video_coords (`torch.Tensor`, *optional*):
+                The video coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape
+                `(batch_size, 3, num_video_tokens, 2)`. If not supplied, this will be calculated inside `forward`.
+            audio_coords (`torch.Tensor`, *optional*):
+                The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape
+                `(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`.
+            attention_kwargs (`Dict[str, Any]`, *optional*):
+                Optional dictionary of keyword arguments to be passed to the attention processor.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a dict-like structured output of type `AudioVisualModelOutput` or a tuple.
 
         Returns:
             `AudioVisualModelOutput` or `tuple`: