diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index f0cb0164436e..ed5f01a0250d 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -367,6 +367,8 @@
       title: LatteTransformer3DModel
     - local: api/models/longcat_image_transformer2d
       title: LongCatImageTransformer2DModel
+    - local: api/models/ltx2_video_transformer3d
+      title: LTX2VideoTransformer3DModel
     - local: api/models/ltx_video_transformer3d
       title: LTXVideoTransformer3DModel
     - local: api/models/lumina2_transformer2d
@@ -443,6 +445,10 @@
       title: AutoencoderKLHunyuanVideo
     - local: api/models/autoencoder_kl_hunyuan_video15
       title: AutoencoderKLHunyuanVideo15
+    - local: api/models/autoencoderkl_audio_ltx_2
+      title: AutoencoderKLLTX2Audio
+    - local: api/models/autoencoderkl_ltx_2
+      title: AutoencoderKLLTX2Video
     - local: api/models/autoencoderkl_ltx_video
       title: AutoencoderKLLTXVideo
     - local: api/models/autoencoderkl_magvit
@@ -678,6 +684,8 @@
       title: Kandinsky 5.0 Video
     - local: api/pipelines/latte
       title: Latte
+    - local: api/pipelines/ltx2
+      title: LTX-2
     - local: api/pipelines/ltx_video
       title: LTXVideo
     - local: api/pipelines/mochi
diff --git a/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md b/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md
new file mode 100644
index 000000000000..d0024474e9e0
--- /dev/null
+++ b/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md
@@ -0,0 +1,29 @@
+
+
+# AutoencoderKLLTX2Audio
+
+The variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) for encoding and decoding audio latent representations was introduced by Lightricks.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import AutoencoderKLLTX2Audio
+
+vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="audio_vae", torch_dtype=torch.float32).to("cuda")
+```
+
+## AutoencoderKLLTX2Audio
+
+[[autodoc]] AutoencoderKLLTX2Audio
+  - encode
+  - decode
+  - all
\ No newline at end of file
diff --git a/docs/source/en/api/models/autoencoderkl_ltx_2.md b/docs/source/en/api/models/autoencoderkl_ltx_2.md
new file mode 100644
index 000000000000..1dbf516c017a
--- /dev/null
+++ b/docs/source/en/api/models/autoencoderkl_ltx_2.md
@@ -0,0 +1,29 @@
+
+
+# AutoencoderKLLTX2Video
+
+The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import AutoencoderKLLTX2Video
+
+vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
+```
+
+## AutoencoderKLLTX2Video
+
+[[autodoc]] AutoencoderKLLTX2Video
+  - decode
+  - encode
+  - all
diff --git a/docs/source/en/api/models/ltx2_video_transformer3d.md b/docs/source/en/api/models/ltx2_video_transformer3d.md
new file mode 100644
index 000000000000..9faab8695468
--- /dev/null
+++ b/docs/source/en/api/models/ltx2_video_transformer3d.md
@@ -0,0 +1,26 @@
+
+
+# LTX2VideoTransformer3DModel
+
+A Diffusion Transformer model for 3D video data from [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import LTX2VideoTransformer3DModel
+
+transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## LTX2VideoTransformer3DModel
+
+[[autodoc]] LTX2VideoTransformer3DModel
diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md
new file mode 100644
index 000000000000..231e3112a907
--- /dev/null
+++ b/docs/source/en/api/pipelines/ltx2.md
@@ -0,0 +1,43 @@
+
+
+# LTX-2
+
+LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
+
+You can find all the original LTX-2 checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
+
+The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).
+
+## LTX2Pipeline
+
+[[autodoc]] LTX2Pipeline
+  - all
+  - __call__
+
+## LTX2ImageToVideoPipeline
+
+[[autodoc]] LTX2ImageToVideoPipeline
+  - all
+  - __call__
+
+## LTX2LatentUpsamplePipeline
+
+[[autodoc]] LTX2LatentUpsamplePipeline
+  - all
+  - __call__
+
+## LTX2PipelineOutput
+
+[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput
diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py
new file mode 100644
index 000000000000..5367113365a2
--- /dev/null
+++ b/scripts/convert_ltx2_to_diffusers.py
@@ -0,0 +1,886 @@
+import argparse
+import os
+from contextlib import nullcontext
+from typing import Any, Dict, Optional, Tuple
+
+import safetensors.torch
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
+
+from diffusers import (
+    AutoencoderKLLTX2Audio,
+    AutoencoderKLLTX2Video,
+    FlowMatchEulerDiscreteScheduler,
+    LTX2LatentUpsamplePipeline,
+    LTX2Pipeline,
+    LTX2VideoTransformer3DModel,
+)
+from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+
+LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
+    # Input Patchify Projections
+    "patchify_proj": "proj_in",
+    "audio_patchify_proj": "audio_proj_in",
+    # Modulation Parameters
+    # Handle adaln_single --> time_embed, audio_adaln_single --> audio_time_embed separately as the original keys are
+    # substrings of the other modulation parameters below
+    "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
+    "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
+    "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
+    "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
+    # Transformer Blocks
+    # Per-Block Cross Attention Modulation Parameters
+    "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
+    "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
+    # Attention QK Norms
+    "q_norm": "norm_q",
+    "k_norm": "norm_k",
+}
+
+LTX_2_0_VIDEO_VAE_RENAME_DICT = {
+    # Encoder
+    "down_blocks.0": "down_blocks.0",
+    "down_blocks.1": "down_blocks.0.downsamplers.0",
+    "down_blocks.2": "down_blocks.1",
+    "down_blocks.3": "down_blocks.1.downsamplers.0",
+    "down_blocks.4": "down_blocks.2",
+    "down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3", + "down_blocks.7": "down_blocks.3.downsamplers.0", + "down_blocks.8": "mid_block", + # Decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + # Common + # For all 3D ResNets + "res_blocks": "resnets", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} + +LTX_2_0_AUDIO_VAE_RENAME_DICT = { + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} + +LTX_2_0_VOCODER_RENAME_DICT = { + "ups": "upsamplers", + "resblocks": "resnets", + "conv_pre": "conv_in", + "conv_post": "conv_out", +} + +LTX_2_0_TEXT_ENCODER_RENAME_DICT = { + "video_embeddings_connector": "video_connector", + "audio_embeddings_connector": "audio_connector", + "transformer_1d_blocks": "transformer_blocks", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + + +def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + +def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None: + state_dict.pop(key) + + +def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None: + # Skip if not a weight, bias + if ".weight" not in key and ".bias" not in key: + return + + if key.startswith("adaln_single."): + new_key = key.replace("adaln_single.", "time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + if key.startswith("audio_adaln_single."): + new_key = key.replace("audio_adaln_single.", "audio_time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + +def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None: + if key.startswith("per_channel_statistics"): + new_key = ".".join(["decoder", key]) + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + +LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = { + "video_embeddings_connector": remove_keys_inplace, + "audio_embeddings_connector": remove_keys_inplace, + "adaln_single": convert_ltx2_transformer_adaln_single, +} + +LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = { + "connectors.": "", + "video_embeddings_connector": "video_connector", + "audio_embeddings_connector": "audio_connector", + "transformer_1d_blocks": "transformer_blocks", + "text_embedding_projection.aggregate_embed": "text_proj_in", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + +LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_inplace, + "per_channel_statistics.mean-of-stds": remove_keys_inplace, +} + +LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {} + +LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} + + +def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + connector_prefixes = ( + "video_embeddings_connector", + "audio_embeddings_connector", + "transformer_1d_blocks", + "text_embedding_projection.aggregate_embed", + "connectors.", + "video_connector", + "audio_connector", + "text_proj_in", + ) + + transformer_state_dict, connector_state_dict = {}, {} + for key, value in state_dict.items(): + if key.startswith(connector_prefixes): + connector_state_dict[key] = value + else: + transformer_state_dict[key] = value + + 
return transformer_state_dict, connector_state_dict + + +def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + # Produces a transformer of the same size as used in test_models_transformer_ltx2.py + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 4, + "out_channels": 4, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "vae_scale_factors": (8, 32, 32), + "pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "audio_in_channels": 4, + "audio_out_channels": 4, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 2, + "audio_attention_head_dim": 4, + "audio_cross_attention_dim": 8, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "num_layers": 2, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 16, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "rope_double_precision": False, + "causal_offset": 1, + "timestep_scale_multiplier": 1000, + "cross_attn_timestep_scale_multiplier": 1, + }, + } + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 128, + "cross_attention_dim": 4096, + "vae_scale_factors": (8, 32, 32), + "pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "audio_in_channels": 128, + "audio_out_channels": 128, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 32, + "audio_attention_head_dim": 64, + "audio_cross_attention_dim": 2048, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "num_layers": 48, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 3840, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_offset": 1, + "timestep_scale_multiplier": 1000, + "cross_attn_timestep_scale_multiplier": 1000, + "rope_type": "split", + }, + } + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "caption_channels": 16, + "text_proj_in_factor": 3, + "video_connector_num_attention_heads": 4, + "video_connector_attention_head_dim": 8, + "video_connector_num_layers": 1, + "video_connector_num_learnable_registers": None, + "audio_connector_num_attention_heads": 4, + "audio_connector_attention_head_dim": 8, + "audio_connector_num_layers": 1, + "audio_connector_num_learnable_registers": None, + "connector_rope_base_seq_len": 32, + "rope_theta": 10000.0, + "rope_double_precision": False, + 
"causal_temporal_positioning": False, + }, + } + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "caption_channels": 3840, + "text_proj_in_factor": 49, + "video_connector_num_attention_heads": 30, + "video_connector_attention_head_dim": 128, + "video_connector_num_layers": 2, + "video_connector_num_learnable_registers": 128, + "audio_connector_num_attention_heads": 30, + "audio_connector_attention_head_dim": 128, + "audio_connector_num_layers": 2, + "audio_connector_num_learnable_registers": 128, + "connector_rope_base_seq_len": 4096, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_temporal_positioning": False, + "rope_type": "split", + }, + } + + rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT + special_keys_remap = {} + + return config, rename_dict, special_keys_remap + + +def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version) + diffusers_config = config["diffusers_config"] + + transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict) + + with init_empty_weights(): + transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(transformer_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(transformer_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(transformer_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, transformer_state_dict) + + transformer.load_state_dict(transformer_state_dict, strict=True, assign=True) + return transformer + + +def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors: + config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version) + diffusers_config = config["diffusers_config"] + + _, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict) + if len(connector_state_dict) == 0: + raise ValueError("No connector weights found in the provided state dict.") + + with init_empty_weights(): + connectors = LTX2TextConnectors.from_config(diffusers_config) + + for key in list(connector_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(connector_state_dict, key, new_key) + + for key in list(connector_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, connector_state_dict) + + connectors.load_state_dict(connector_state_dict, strict=True, assign=True) + return connectors + + +def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 2048), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + 
"LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 2048), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vae = AutoencoderKLLTX2Video.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + +def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "base_channels": 128, + "output_channels": 2, + "ch_mult": (1, 
2, 4), + "num_res_blocks": 2, + "attn_resolutions": None, + "in_channels": 2, + "resolution": 256, + "latent_channels": 8, + "norm_type": "pixel", + "causality_axis": "height", + "dropout": 0.0, + "mid_block_add_attention": False, + "sample_rate": 16000, + "mel_hop_length": 160, + "is_causal": True, + "mel_bins": 64, + "double_z": True, + }, + } + rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vae = AutoencoderKLLTX2Audio.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + +def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "in_channels": 128, + "hidden_channels": 1024, + "out_channels": 2, + "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_factors": [6, 5, 2, 2, 2], + "resnet_kernel_sizes": [3, 7, 11], + "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "leaky_relu_negative_slope": 0.1, + "output_sampling_rate": 24000, + }, + } + rename_dict = LTX_2_0_VOCODER_RENAME_DICT + special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vocoder = LTX2Vocoder.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vocoder.load_state_dict(original_state_dict, strict=True, assign=True) + return vocoder + + +def get_ltx2_spatial_latent_upsampler_config(version: str): + if version == "2.0": + config = { + "in_channels": 128, + "mid_channels": 1024, + "num_blocks_per_stage": 4, + "dims": 3, + "spatial_upsample": True, + "temporal_upsample": False, + 
"rational_spatial_scale": 2.0, + } + else: + raise ValueError(f"Unsupported version: {version}") + return config + + +def convert_ltx2_spatial_latent_upsampler( + original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype +): + with init_empty_weights(): + latent_upsampler = LTX2LatentUpsamplerModel(**config) + + latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True) + latent_upsampler.to(dtype) + return latent_upsampler + + +def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + + +def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]: + if repo_id is None and filename is None: + raise ValueError("Please supply at least one of `repo_id` or `filename`") + + if repo_id is not None: + if filename is None: + raise ValueError("If repo_id is specified, filename must also be specified.") + ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) + else: + ckpt_path = filename + + _, ext = os.path.splitext(ckpt_path) + if ext in [".safetensors", ".sft"]: + state_dict = safetensors.torch.load_file(ckpt_path) + else: + state_dict = torch.load(ckpt_path, map_location="cpu") + + return state_dict + + +def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]: + # Ensure that the key prefix ends with a dot (.) + if not prefix.endswith("."): + prefix = prefix + "." + + model_state_dict = {} + for param_name, param in combined_ckpt.items(): + if param_name.startswith(prefix): + model_state_dict[param_name.replace(prefix, "")] = param + + if prefix == "model.diffusion_model.": + # Some checkpoints store the text connector projection outside the diffusion model prefix. + connector_key = "text_embedding_projection.aggregate_embed.weight" + if connector_key in combined_ckpt and connector_key not in model_state_dict: + model_state_dict[connector_key] = combined_ckpt[connector_key] + + return model_state_dict + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--original_state_dict_repo_id", + default="Lightricks/LTX-2", + type=str, + help="HF Hub repo id with LTX 2.0 checkpoint", + ) + parser.add_argument( + "--checkpoint_path", + default=None, + type=str, + help="Local checkpoint path for LTX 2.0. 
Will be used if `original_state_dict_repo_id` is not specified.", + ) + parser.add_argument( + "--version", + type=str, + default="2.0", + choices=["test", "2.0"], + help="Version of the LTX 2.0 model", + ) + + parser.add_argument( + "--combined_filename", + default="ltx-2-19b-dev.safetensors", + type=str, + help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)", + ) + parser.add_argument("--vae_prefix", default="vae.", type=str) + parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str) + parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str) + parser.add_argument("--vocoder_prefix", default="vocoder.", type=str) + + parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set") + parser.add_argument( + "--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set" + ) + parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set") + parser.add_argument( + "--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set" + ) + parser.add_argument( + "--text_encoder_model_id", + default="google/gemma-3-12b-it-qat-q4_0-unquantized", + type=str, + help="HF Hub id for the LTX 2.0 base text encoder model", + ) + parser.add_argument( + "--tokenizer_id", + default="google/gemma-3-12b-it-qat-q4_0-unquantized", + type=str, + help="HF Hub id for the LTX 2.0 text tokenizer", + ) + parser.add_argument( + "--latent_upsampler_filename", + default="ltx-2-spatial-upscaler-x2-1.0.safetensors", + type=str, + help="Latent upsampler filename", + ) + + parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") + parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") + parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") + parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model") + parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model") + parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder") + parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler") + parser.add_argument( + "--full_pipeline", + action="store_true", + help="Whether to save the pipeline. This will attempt to convert all models (e.g. 
vae, dit, etc.)", + ) + parser.add_argument( + "--upsample_pipeline", + action="store_true", + help="Whether to save a latent upsampling pipeline", + ) + + parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +VARIANT_MAPPING = { + "fp32": None, + "fp16": "fp16", + "bf16": "bf16", +} + + +def main(args): + vae_dtype = DTYPE_MAPPING[args.vae_dtype] + audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype] + dit_dtype = DTYPE_MAPPING[args.dit_dtype] + vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype] + text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype] + + combined_ckpt = None + load_combined_models = any( + [ + args.vae, + args.audio_vae, + args.dit, + args.vocoder, + args.text_encoder, + args.full_pipeline, + args.upsample_pipeline, + ] + ) + if args.combined_filename is not None and load_combined_models: + combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename) + + if args.vae or args.full_pipeline or args.upsample_pipeline: + if args.vae_filename is not None: + original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) + elif combined_ckpt is not None: + original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) + vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version) + if not args.full_pipeline and not args.upsample_pipeline: + vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) + + if args.audio_vae or args.full_pipeline: + if args.audio_vae_filename is not None: + original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename) + elif combined_ckpt is not None: + original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix) + audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version) + if not args.full_pipeline: + audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae")) + + if args.dit or args.full_pipeline: + if args.dit_filename is not None: + original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename) + elif combined_ckpt is not None: + original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix) + transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version) + if not args.full_pipeline: + transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer")) + + if args.connectors or args.full_pipeline: + if args.dit_filename is not None: + original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename) + elif combined_ckpt is not None: + original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix) + connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version) + if not args.full_pipeline: + 
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors")) + + if args.vocoder or args.full_pipeline: + if args.vocoder_filename is not None: + original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename) + elif combined_ckpt is not None: + original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix) + vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version) + if not args.full_pipeline: + vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder")) + + if args.text_encoder or args.full_pipeline: + # text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id) + if not args.full_pipeline: + text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder")) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id) + if not args.full_pipeline: + tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer")) + + if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline: + original_latent_upsampler_ckpt = load_hub_or_local_checkpoint( + repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename + ) + latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version) + latent_upsampler = convert_ltx2_spatial_latent_upsampler( + original_latent_upsampler_ckpt, + latent_upsampler_config, + dtype=vae_dtype, + ) + if not args.full_pipeline and not args.upsample_pipeline: + latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler")) + + if args.full_pipeline: + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=True, + base_shift=0.95, + max_shift=2.05, + base_image_seq_len=1024, + max_image_seq_len=4096, + shift_terminal=0.1, + ) + + pipe = LTX2Pipeline( + scheduler=scheduler, + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + ) + + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if args.upsample_pipeline: + pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler) + + # Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline + pipe.save_pretrained( + os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB" + ) + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3a50634d82d8..c749bad4be47 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -193,6 +193,8 @@ "AutoencoderKLHunyuanImageRefiner", "AutoencoderKLHunyuanVideo", "AutoencoderKLHunyuanVideo15", + "AutoencoderKLLTX2Audio", + "AutoencoderKLLTX2Video", "AutoencoderKLLTXVideo", "AutoencoderKLMagvit", "AutoencoderKLMochi", @@ -236,6 +238,7 @@ "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", "LongCatImageTransformer2DModel", + "LTX2VideoTransformer3DModel", "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", @@ -538,6 +541,9 @@ "LEditsPPPipelineStableDiffusionXL", "LongCatImageEditPipeline", "LongCatImagePipeline", + "LTX2ImageToVideoPipeline", + "LTX2LatentUpsamplePipeline", + "LTX2Pipeline", "LTXConditionPipeline", "LTXI2VLongMultiPromptPipeline", "LTXImageToVideoPipeline", @@ -939,6 
+945,8 @@ AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, @@ -982,6 +990,7 @@ Kandinsky5Transformer3DModel, LatteTransformer3DModel, LongCatImageTransformer2DModel, + LTX2VideoTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, @@ -1254,6 +1263,9 @@ LEditsPPPipelineStableDiffusionXL, LongCatImageEditPipeline, LongCatImagePipeline, + LTX2ImageToVideoPipeline, + LTX2LatentUpsamplePipeline, + LTX2Pipeline, LTXConditionPipeline, LTXI2VLongMultiPromptPipeline, LTXImageToVideoPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index c4664f00cad2..81730b7516be 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -41,6 +41,8 @@ _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"] _import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] + _import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"] + _import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"] @@ -104,6 +106,7 @@ _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] + _import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] @@ -153,6 +156,8 @@ AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, @@ -212,6 +217,7 @@ Kandinsky5Transformer3DModel, LatteTransformer3DModel, LongCatImageTransformer2DModel, + LTX2VideoTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 56df27f93cd7..8e7a9c81d2ad 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -10,6 +10,8 @@ from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo +from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video +from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py 
b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py
new file mode 100644
index 000000000000..01dd55a938b6
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py
@@ -0,0 +1,1521 @@
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
+
+
+class PerChannelRMSNorm(nn.Module):
+    """
+    Per-pixel (per-location) RMS normalization layer.
+
+    At each location, this layer normalizes the tensor by the root-mean-square of its values across the chosen
+    dimension:
+
+    y = x / sqrt(mean(x^2, dim=channel_dim, keepdim=True) + eps)
+    """
+
+    def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None:
+        """
+        Args:
+            channel_dim: Dimension along which to compute the RMS (typically channels).
+            eps: Small constant added for numerical stability.
+        """
+        super().__init__()
+        self.channel_dim = channel_dim
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor, channel_dim: Optional[int] = None) -> torch.Tensor:
+        """
+        Apply RMS normalization along the configured dimension (or along `channel_dim`, if provided).
+        """
+        channel_dim = self.channel_dim if channel_dim is None else channel_dim
+        # Compute mean of squared values along `channel_dim`, keep dimensions for broadcasting.
+        mean_sq = torch.mean(x**2, dim=channel_dim, keepdim=True)
+        # Normalize by the root-mean-square (RMS).
+ rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime +class LTX2VideoCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + + dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + height_pad = self.kernel_size[1] // 2 + width_pad = self.kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + self.kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + padding=padding, + padding_mode=spatial_padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: + time_kernel_size = self.kernel_size[0] + + if causal: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) + else: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding +# mode is configurable +class LTX2VideoResnetBlock3d(nn.Module): + r""" + A 3D ResNet block used in the LTX 2.0 audiovisual model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + elementwise_affine (`bool`, defaults to `False`): + Whether to enable elementwise affinity in the normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + eps: float = 1e-6, + elementwise_affine: bool = False, + non_linearity: str = "swish", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = PerChannelRMSNorm() + self.conv1 = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, + ) + + self.norm2 = PerChannelRMSNorm() + self.dropout = nn.Dropout(dropout) + self.conv2 = LTX2VideoCausalConv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, + ) + + self.norm3 = None + self.conv_shortcut = None + if in_channels != out_channels: + self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) + # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d + self.conv_shortcut = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1) + + self.per_channel_scale1 = None + self.per_channel_scale2 = None + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + + self.scale_shift_table = None + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) + + def forward( + self, + inputs: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + causal: bool = True, + ) -> torch.Tensor: + hidden_states = inputs + + hidden_states = self.norm1(hidden_states) + + if self.scale_shift_table is not None: + temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] + shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale_1) + shift_1 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.per_channel_scale1 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] + + hidden_states = self.norm2(hidden_states) + + if self.scale_shift_table is not None: + hidden_states = hidden_states * (1 + scale_2) + shift_2 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.per_channel_scale2 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...] 
+ + if self.norm3 is not None: + inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) + + if self.conv_shortcut is not None: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states + + +# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d +class LTXVideoDownsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels + + out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) + + self.conv = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: + hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) + + residual = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + residual = residual.unflatten(1, (-1, self.group_size)) + residual = residual.mean(dim=2) + + hidden_states = self.conv(hidden_states, causal=causal) + hidden_states = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + hidden_states = hidden_states + residual + + return hidden_states + + +# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d +class LTXVideoUpsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + residual: bool = False, + upscale_factor: int = 1, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.residual = residual + self.upscale_factor = upscale_factor + + out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor + + self.conv = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.residual: + residual = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor + residual = residual.repeat(1, repeats, 1, 1, 1) + residual = residual[:, :, self.stride[0] - 1 :] + + hidden_states = self.conv(hidden_states, causal=causal) + hidden_states = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + + if self.residual: + hidden_states = 
hidden_states + residual + + return hidden_states + + +# Like LTX 1.0 LTXVideo095DownBlock3D, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoDownBlock3D(nn.Module): + r""" + Down block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + downsample_type: str = "conv", + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + spatial_padding_mode=spatial_padding_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList() + + if downsample_type == "conv": + self.downsamplers.append( + LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif downsample_type == "spatial": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif downsample_type == "temporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif downsample_type == "spatiotemporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + causal: bool = True, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) + else: + hidden_states = resnet(hidden_states, temb, generator, causal=causal) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, causal=causal) + + return hidden_states + + +# Adapted from 
diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d +# Like LTX 1.0 LTXVideoMidBlock3d, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoMidBlock3d(nn.Module): + r""" + A middle block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + causal: bool = True, + ) -> torch.Tensor: + r"""Forward method of the `LTXMidBlock3D` class.""" + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) + else: + hidden_states = resnet(hidden_states, temb, generator, causal=causal) + + return hidden_states + + +# Like LTXVideoUpBlock3d but with no conv_in and the updated LTX2VideoResnetBlock3d +class LTX2VideoUpBlock3d(nn.Module): + r""" + Up block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. 
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + upsample_residual: bool = False, + upscale_factor: int = 1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + self.conv_in = None + if in_channels != out_channels: + self.conv_in = LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + + self.upsamplers = None + if spatio_temporal_scale: + self.upsamplers = nn.ModuleList( + [ + LTXVideoUpsampler3d( + out_channels * upscale_factor, + stride=(2, 2, 2), + residual=upsample_residual, + upscale_factor=upscale_factor, + spatial_padding_mode=spatial_padding_mode, + ) + ] + ) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=out_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + causal: bool = True, + ) -> torch.Tensor: + if self.conv_in is not None: + hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, causal=causal) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) + else: + hidden_states = resnet(hidden_states, temb, generator, causal=causal) + + return hidden_states + + +# Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is +# different, as is the layers_per_block (the 2.0 VAE is bigger) +class LTX2VideoEncoder3d(nn.Module): + r""" + The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + in_channels (`int`, defaults to 3): + Number of input channels. + out_channels (`int`, defaults to 128): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, True)`: + Whether a block should contain spatio-temporal downscaling layers or not. 
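+        down_block_types (`Tuple[str, ...]`, defaults to `("LTX2VideoDownBlock3D", "LTX2VideoDownBlock3D", "LTX2VideoDownBlock3D", "LTX2VideoDownBlock3D")`):
+            The type of each down block to use.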
+ layers_per_block (`Tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`): + The number of layers per block. + downsample_type (`Tuple[str, ...]`, defaults to `("spatial", "temporal", "spatiotemporal", "spatiotemporal")`): + The spatiotemporal downsampling pattern per block. Per-layer values can be + - `"spatial"` (downsample spatial dims by 2x) + - `"temporal"` (downsample temporal dim by 2x) + - `"spatiotemporal"` (downsample both spatial and temporal dims by 2x) + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 128, + block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: Tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True), + layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2), + downsample_type: Tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.in_channels = in_channels * patch_size**2 + self.is_causal = is_causal + + output_channel = out_channels + + self.conv_in = LTX2VideoCausalConv3d( + in_channels=self.in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + # down blocks + num_block_out_channels = len(block_out_channels) + self.down_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + output_channel = block_out_channels[i] + + if down_block_types[i] == "LTX2VideoDownBlock3D": + down_block = LTX2VideoDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + downsample_type=downsample_type[i], + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"Unknown down block type: {down_block_types[i]}") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[-1], + resnet_eps=resnet_norm_eps, + spatial_padding_mode=spatial_padding_mode, + ) + + # out + self.norm_out = PerChannelRMSNorm() + self.conv_act = nn.SiLU() + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=out_channels + 1, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: + r"""The forward method of the `LTXVideoEncoder3d` class.""" + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + causal = causal or self.is_causal + + hidden_states = 
hidden_states.reshape( + batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p + ) + # Thanks for driving me insane with the weird patching order :( + hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) + hidden_states = self.conv_in(hidden_states, causal=causal) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states, None, None, causal) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, None, None, causal) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states, causal=causal) + + hidden_states = self.mid_block(hidden_states, causal=causal) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) + + last_channel = hidden_states[:, -1:] + last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) + hidden_states = torch.cat([hidden_states, last_channel], dim=1) + + return hidden_states + + +# Like LTX 1.0 LTXVideoDecoder3d, but has only 3 symmetric up blocks which are causal and residual with upsample_factor 2 +class LTX2VideoDecoder3d(nn.Module): + r""" + The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, defaults to 128): + Number of latent channels. + out_channels (`int`, defaults to 3): + Number of output channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal upscaling layers or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `False`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + timestep_conditioning (`bool`, defaults to `False`): + Whether to condition the model on timesteps. + """ + + def __init__( + self, + in_channels: int = 128, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (256, 512, 1024), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True), + layers_per_block: Tuple[int, ...] = (5, 5, 5, 5), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = False, + inject_noise: Tuple[bool, ...] = (False, False, False), + timestep_conditioning: bool = False, + upsample_residual: Tuple[bool, ...] = (True, True, True), + upsample_factor: Tuple[bool, ...] 
= (2, 2, 2), + spatial_padding_mode: str = "reflect", + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.out_channels = out_channels * patch_size**2 + self.is_causal = is_causal + + block_out_channels = tuple(reversed(block_out_channels)) + spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) + layers_per_block = tuple(reversed(layers_per_block)) + inject_noise = tuple(reversed(inject_noise)) + upsample_residual = tuple(reversed(upsample_residual)) + upsample_factor = tuple(reversed(upsample_factor)) + output_channel = block_out_channels[0] + + self.conv_in = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[0], + resnet_eps=resnet_norm_eps, + inject_noise=inject_noise[0], + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + + # up blocks + num_block_out_channels = len(block_out_channels) + self.up_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel // upsample_factor[i] + output_channel = block_out_channels[i] // upsample_factor[i] + + up_block = LTX2VideoUpBlock3d( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i + 1], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + inject_noise=inject_noise[i + 1], + timestep_conditioning=timestep_conditioning, + upsample_residual=upsample_residual[i], + upscale_factor=upsample_factor[i], + spatial_padding_mode=spatial_padding_mode, + ) + + self.up_blocks.append(up_block) + + # out + self.norm_out = PerChannelRMSNorm() + self.conv_act = nn.SiLU() + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + # timestep embedding + self.time_embedder = None + self.scale_shift_table = None + self.timestep_scale_multiplier = None + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + ) -> torch.Tensor: + causal = causal or self.is_causal + + hidden_states = self.conv_in(hidden_states, causal=causal) + + if self.timestep_scale_multiplier is not None: + temb = temb * self.timestep_scale_multiplier + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb, None, causal) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb, None, causal) + else: + hidden_states = self.mid_block(hidden_states, temb, causal=causal) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb, causal=causal) + + hidden_states = self.norm_out(hidden_states) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), 
+ hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1)) + temb = temb + self.scale_shift_table[None, ..., None, None, None] + shift, scale = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return hidden_states + + +class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [LTX-2](https://huggingface.co/Lightricks/LTX-2). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Args: + in_channels (`int`, defaults to `3`): + Number of input channels. + out_channels (`int`, defaults to `3`): + Number of output channels. + latent_channels (`int`, defaults to `128`): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal downscaling or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + scaling_factor (`float`, *optional*, defaults to `1.0`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + encoder_causal (`bool`, defaults to `True`): + Whether the encoder should behave causally (future frames depend only on past frames) or not. + decoder_causal (`bool`, defaults to `False`): + Whether the decoder should behave causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 128, + block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: Tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + decoder_block_out_channels: Tuple[int, ...] = (256, 512, 1024), + layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2), + decoder_layers_per_block: Tuple[int, ...] 
= (5, 5, 5, 5), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True), + decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True), + decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False), + downsample_type: Tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + upsample_residual: Tuple[bool, ...] = (True, True, True), + upsample_factor: Tuple[int, ...] = (2, 2, 2), + timestep_conditioning: bool = False, + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + scaling_factor: float = 1.0, + encoder_causal: bool = True, + decoder_causal: bool = True, + encoder_spatial_padding_mode: str = "zeros", + decoder_spatial_padding_mode: str = "reflect", + spatial_compression_ratio: int = None, + temporal_compression_ratio: int = None, + ) -> None: + super().__init__() + + self.encoder = LTX2VideoEncoder3d( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=block_out_channels, + down_block_types=down_block_types, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + downsample_type=downsample_type, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=encoder_causal, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + self.decoder = LTX2VideoDecoder3d( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=decoder_block_out_channels, + spatio_temporal_scaling=decoder_spatio_temporal_scaling, + layers_per_block=decoder_layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=decoder_causal, + timestep_conditioning=timestep_conditioning, + inject_noise=decoder_inject_noise, + upsample_residual=upsample_residual, + upsample_factor=upsample_factor, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + latents_mean = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.ones((latent_channels,), requires_grad=False) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # This can be configured based on the amount of GPU memory available. + # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. 
+ # Setting it to higher values results in higher memory usage. + self.num_sample_frames_batch_size = 16 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + self.tile_sample_min_num_frames = 16 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + self.tile_sample_stride_num_frames = 8 + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_min_num_frames: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + tile_sample_stride_num_frames: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def _encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x, causal=causal) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x, causal=causal) + + enc = self.encoder(x, causal=causal) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, causal: Optional[bool] = None, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
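+
+        When `causal` is not provided, the encoder falls back to the `encoder_causal` behavior configured for this
+        model.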
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice, causal=causal) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x, causal=causal) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode( + self, + z: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + return_dict: bool = True, + ) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: + return self._temporal_tiled_decode(z, temb, causal=causal, return_dict=return_dict) + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, temb, causal=causal, return_dict=return_dict) + + dec = self.decoder(z, temb, causal=causal) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, + z: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + return_dict: bool = True, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + if temb is not None: + decoded_slices = [ + self._decode(z_slice, t_slice, causal=causal).sample + for z_slice, t_slice in (z.split(1), temb.split(1)) + ] + else: + decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z, temb, causal=causal).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. 
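+
+        The input is split into overlapping spatial tiles, each tile is encoded separately, and neighboring tiles are
+        blended across the overlap region so that no seams are visible in the resulting latent.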
+ + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + time = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width], + causal=causal, + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
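+        # Each decoded tile is blended with its top and left neighbors over the overlapping region (see `blend_v`
+        # and `blend_h`) and then cropped to the tile stride before the tiles are stitched back together.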
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + time = self.decoder( + z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb, causal=causal + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile, causal=causal) + else: + tile = self.encoder(tile, causal=causal) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, temb, causal=causal, return_dict=True).sample + else: + decoded = self.decoder(tile, temb, causal=causal) + if i > 0: + decoded = decoded[:, :, :-1, :, :] + row.append(decoded) + + result_row = [] + for i, tile 
in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :] + result_row.append(tile) + else: + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + temb: Optional[torch.Tensor] = None, + sample_posterior: bool = False, + encoder_causal: Optional[bool] = None, + decoder_causal: Optional[bool] = None, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x, causal=encoder_causal).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, temb, causal=decoder_causal) + if not return_dict: + return (dec.sample,) + return dec diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py new file mode 100644 index 000000000000..6c9c7dce3d2f --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -0,0 +1,804 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +class LTX2AudioCausalConv2d(nn.Module): + """ + A causal 2D convolution that pads asymmetrically along the causal axis. 
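+
+    For `"height"` or `"width"`, all padding along that axis is applied on the leading side, so an output position
+    never depends on later ("future") positions along the causal axis; `"none"` pads symmetrically on both sides.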
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: int = 1, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: str = "height", + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + dilation = (dilation, dilation) if isinstance(dilation, int) else dilation + + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + if self.causality_axis == "none": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis in {"width", "width-compatibility"}: + padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis == "height": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + else: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + self.padding = padding + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding) + return self.conv(x) + + +class LTX2AudioPixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +class LTX2AudioAttnBlock(nn.Module): + def __init__( + self, + in_channels: int, + norm_type: str = "group", + ) -> None: + super().__init__() + self.in_channels = in_channels + + if norm_type == "group": + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = self.norm(x) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + batch, channels, height, width = q.shape + q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous() + k = k.reshape(batch, channels, height * width).contiguous() + attn = torch.bmm(q, k) * (int(channels) ** (-0.5)) + attn = torch.nn.functional.softmax(attn, dim=2) + + v = v.reshape(batch, channels, height * width) + attn = attn.permute(0, 2, 1).contiguous() + h_ = torch.bmm(v, attn).reshape(batch, channels, height, width) + + h_ = self.proj_out(h_) + return x + h_ + + +class LTX2AudioResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: str = "group", + causality_axis: str = "height", + ) -> None: + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group": + raise 
ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + if norm_type == "group": + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") + self.non_linearity = nn.SiLU() + if causality_axis is not None: + self.conv1 = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + if norm_type == "group": + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") + self.dropout = nn.Dropout(dropout) + if causality_axis is not None: + self.conv2 = LTX2AudioCausalConv2d( + out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + if causality_axis is not None: + self.conv_shortcut = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + if causality_axis is not None: + self.nin_shortcut = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + h = self.norm1(x) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) + + return x + h + + +class LTX2AudioDownsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.with_conv: + # Padding tuple is in the order: (left, right, top, bottom). + if self.causality_axis == "none": + pad = (0, 1, 0, 1) + elif self.causality_axis == "width": + pad = (2, 0, 0, 1) + elif self.causality_axis == "height": + pad = (0, 1, 2, 0) + elif self.causality_axis == "width-compatibility": + pad = (1, 0, 0, 1) + else: + raise ValueError( + f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`," + f" and `width-compatibility`." 
+ ) + + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # with_conv=False implies that causality_axis is "none" + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class LTX2AudioUpsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + if causality_axis is not None: + self.conv = LTX2AudioCausalConv2d( + in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + if self.causality_axis is None or self.causality_axis == "none": + pass + elif self.causality_axis == "height": + x = x[:, :, 1:, :] + elif self.causality_axis == "width": + x = x[:, :, :, 1:] + elif self.causality_axis == "width-compatibility": + pass + else: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class LTX2AudioAudioPatchifier: + """ + Patchifier for spectrogram/audio latents. + """ + + def __init__( + self, + patch_size: int, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + is_causal: bool = True, + ): + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self._patch_size = (1, patch_size, patch_size) + + def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor: + batch, channels, time, freq = audio_latents.shape + return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq) + + def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor: + batch, time, _ = audio_latents.shape + return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + +class LTX2AudioEncoder(nn.Module): + def __init__( + self, + base_channels: int = 128, + output_channels: int = 1, + num_res_blocks: int = 2, + attn_resolutions: Optional[Tuple[int, ...]] = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + ch_mult: Tuple[int, ...] 
= (1, 2, 4), + norm_type: str = "group", + causality_axis: Optional[str] = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = 64, + double_z: bool = True, + ): + super().__init__() + + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + + base_block_channels = base_channels + base_resolution = resolution + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + if self.causality_axis is not None: + self.conv_in = LTX2AudioCausalConv2d( + in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1) + + self.down = nn.ModuleList() + block_in = base_block_channels + curr_res = self.resolution + + for level in range(self.num_resolutions): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) + + if level != self.num_resolutions - 1: + stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis) + curr_res = curr_res // 2 + + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + + final_block_channels = block_in + z_channels = 2 * latent_channels if double_z else latent_channels + if self.norm_type == "group": + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True) + elif self.norm_type == "pixel": + self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {self.norm_type}") + self.non_linearity = nn.SiLU() + + if self.causality_axis is not None: + self.conv_out = LTX2AudioCausalConv2d( + final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, 
stride=1, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # hidden_states expected shape: (batch_size, channels, time, num_mel_bins) + hidden_states = self.conv_in(hidden_states) + + for level in range(self.num_resolutions): + stage = self.down[level] + for block_idx, block in enumerate(stage.block): + hidden_states = block(hidden_states, temb=None) + if stage.attn: + hidden_states = stage.attn[block_idx](hidden_states) + + if level != self.num_resolutions - 1 and hasattr(stage, "downsample"): + hidden_states = stage.downsample(hidden_states) + + hidden_states = self.mid.block_1(hidden_states, temb=None) + hidden_states = self.mid.attn_1(hidden_states) + hidden_states = self.mid.block_2(hidden_states, temb=None) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class LTX2AudioDecoder(nn.Module): + """ + Symmetric decoder that reconstructs audio spectrograms from latent features. + + The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal + convolutions. + """ + + def __init__( + self, + base_channels: int = 128, + output_channels: int = 1, + num_res_blocks: int = 2, + attn_resolutions: Optional[Tuple[int, ...]] = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + ch_mult: Tuple[int, ...] = (1, 2, 4), + norm_type: str = "group", + causality_axis: Optional[str] = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = 64, + ) -> None: + super().__init__() + + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = LTX2AudioAudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + + base_block_channels = base_channels * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + if self.causality_axis is not None: + self.conv_in = LTX2AudioCausalConv2d( + latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1) + self.non_linearity = nn.SiLU() + self.mid = nn.Module() + self.mid.block_1 = LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = 
LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + + self.up = nn.ModuleList() + block_in = base_block_channels + curr_res = self.resolution // (2 ** (self.num_resolutions - 1)) + + for level in reversed(range(self.num_resolutions)): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks + 1): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) + + if level != 0: + stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis) + curr_res *= 2 + + self.up.insert(0, stage) + + final_block_channels = block_in + + if self.norm_type == "group": + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True) + elif self.norm_type == "pixel": + self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {self.norm_type}") + + if self.causality_axis is not None: + self.conv_out = LTX2AudioCausalConv2d( + final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1) + + def forward( + self, + sample: torch.Tensor, + ) -> torch.Tensor: + _, _, frames, mel_bins = sample.shape + + target_frames = frames * LATENT_DOWNSAMPLE_FACTOR + + if self.causality_axis is not None: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_channels = self.out_ch + target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins + + hidden_features = self.conv_in(sample) + hidden_features = self.mid.block_1(hidden_features, temb=None) + hidden_features = self.mid.attn_1(hidden_features) + hidden_features = self.mid.block_2(hidden_features, temb=None) + + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + hidden_features = block(hidden_features, temb=None) + if stage.attn: + hidden_features = stage.attn[block_idx](hidden_features) + + if level != 0 and hasattr(stage, "upsample"): + hidden_features = stage.upsample(hidden_features) + + if self.give_pre_end: + return hidden_features + + hidden = self.norm_out(hidden_features) + hidden = self.non_linearity(hidden) + decoded_output = self.conv_out(hidden) + decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output + + _, _, current_time, current_freq = decoded_output.shape + target_time = target_frames + target_freq = target_mel_bins + + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + if time_padding_needed > 0 or freq_padding_needed > 0: + padding = ( + 0, + max(freq_padding_needed, 0), + 0, + max(time_padding_needed, 0), + ) + decoded_output = 
F.pad(decoded_output, padding) + + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + +class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + LTX2 audio VAE for encoding and decoding audio latent representations. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + base_channels: int = 128, + output_channels: int = 2, + ch_mult: Tuple[int, ...] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: Optional[Tuple[int, ...]] = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + norm_type: str = "pixel", + causality_axis: Optional[str] = "height", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = 64, + double_z: bool = True, + ) -> None: + super().__init__() + + supported_causality_axes = {"none", "width", "height", "width-compatibility"} + if causality_axis not in supported_causality_axes: + raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}") + + attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions + + self.encoder = LTX2AudioEncoder( + base_channels=base_channels, + output_channels=output_channels, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolution_set, + in_channels=in_channels, + resolution=resolution, + latent_channels=latent_channels, + norm_type=norm_type, + causality_axis=causality_axis, + dropout=dropout, + mid_block_add_attention=mid_block_add_attention, + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + double_z=double_z, + ) + + self.decoder = LTX2AudioDecoder( + base_channels=base_channels, + output_channels=output_channels, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolution_set, + in_channels=in_channels, + resolution=resolution, + latent_channels=latent_channels, + norm_type=norm_type, + causality_axis=causality_axis, + dropout=dropout, + mid_block_add_attention=mid_block_add_attention, + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + ) + + # Per-channel statistics for normalizing and denormalizing the latent representation. 
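(The comment above and below describes the per-channel latent statistics; the sketch here only illustrates the convention, on the assumption that they are applied the same way as in other diffusers VAEs. In practice the actual mean/std values come from the checkpoint rather than the placeholder buffers registered below.)

```python
# Minimal sketch (assumption: standard diffusers VAE convention) of how per-channel
# statistics are used: latents are standardized before the diffusion transformer
# sees them and de-standardized again before decoding.
import torch

latent_channels, frames, bins = 8, 4, 16
latents = torch.randn(1, latent_channels, frames, bins)
# Toy values; in practice these are loaded from the checkpoint's state dict.
latents_mean = torch.zeros(latent_channels).view(1, -1, 1, 1)
latents_std = torch.ones(latent_channels).view(1, -1, 1, 1)

normalized = (latents - latents_mean) / latents_std       # fed to the transformer
denormalized = normalized * latents_std + latents_mean    # fed to vae.decode(...)
assert torch.allclose(denormalized, latents)
```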
This statics is computed over + # the entire dataset and stored in model's checkpoint under AudioVAE state_dict + latents_std = torch.zeros((base_channels,)) + latents_mean = torch.ones((base_channels,)) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + # TODO: calculate programmatically instead of hardcoding + self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4 + # TODO: confirm whether the mel compression ratio below is correct + self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + return self.encoder(x) + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True): + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + return self.decoder(z) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + posterior = self.encode(sample).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec.sample,) + return dec diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 40b5d4a0dfc9..f0c65202d080 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -35,6 +35,7 @@ from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_ltx import LTXVideoTransformer3DModel + from .transformer_ltx2 import LTX2VideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py new file mode 100644 index 000000000000..b88f096e8033 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -0,0 +1,1350 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + is_torch_version, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + +def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + + x_dtype = x.dtype + needs_reshape = False + if x.ndim != 4 and cos.ndim == 4: + # cos is (#b, h, t, r) -> reshape x to (b, h, t, dim_per_head) + # The cos/sin batch dim may only be broadcastable, so take batch size from x + b = x.shape[0] + _, h, t, _ = cos.shape + x = x.reshape(b, t, h, -1).swapaxes(1, 2) + needs_reshape = True + + # Split last dim (2*r) into (d=2, r) + last = x.shape[-1] + if last % 2 != 0: + raise ValueError(f"Expected x.shape[-1] to be even for split rotary, got {last}.") + r = last // 2 + + # (..., 2, r) + split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float + first_x = split_x[..., :1, :] # (..., 1, r) + second_x = split_x[..., 1:, :] # (..., 1, r) + + cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r) + sin_u = sin.unsqueeze(-2) + + out = split_x * cos_u + first_out = out[..., :1, :] + second_out = out[..., 1:, :] + + first_out.addcmul_(-sin_u, second_x) + second_out.addcmul_(sin_u, first_x) + + out = out.reshape(*out.shape[:-2], last) + + if needs_reshape: + out = out.swapaxes(1, 2).reshape(b, t, -1) + + out = out.to(dtype=x_dtype) + return out + + +@dataclass +class AudioVisualModelOutput(BaseOutput): + r""" + Holds the output of an audiovisual model which produces both visual (e.g. video) and audio outputs. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on the `encoder_hidden_states` input, representing the visual output + of the model. This is typically a video (spatiotemporal) output. + audio_sample (`torch.Tensor` of shape `(batch_size, TODO)`): + The audio output of the audiovisual model. + """ + + sample: "torch.Tensor" # noqa: F821 + audio_sample: "torch.Tensor" # noqa: F821 + + +class LTX2AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted by the LTX-2.0 + model. 
In particular, the number of modulation parameters to be calculated is now configurable. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_mod_params (`int`, *optional*, defaults to `6`): + The number of modulation parameters which will be calculated in the first return argument. The default of 6 + is standard, but sometimes we may want to have a different (usually smaller) number of modulation + parameters. + use_additional_conditions (`bool`, *optional*, defaults to `False`): + Whether to use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, num_mod_params: int = 6, use_additional_conditions: bool = False): + super().__init__() + self.num_mod_params = num_mod_params + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, self.num_mod_params * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class LTX2AudioVideoAttnProcessor: + r""" + Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model. + Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can + support audio-to-video (a2v) and video-to-audio (v2a) cross attention. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if is_torch_version("<", "2.0"): + raise ValueError( + "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." 
+ ) + + def __call__( + self, + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if query_rotary_emb is not None: + if attn.rope_type == "interleaved": + query = apply_interleaved_rotary_emb(query, query_rotary_emb) + key = apply_interleaved_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + elif attn.rope_type == "split": + query = apply_split_rotary_emb(query, query_rotary_emb) + key = apply_split_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LTX2Attention(torch.nn.Module, AttentionModuleMixin): + r""" + Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key + RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention. 
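The following self-contained toy illustrates the idea described above. The helper names (`toy_rope`, `apply_rope`) and all sizes are illustrative, not part of diffusers: video-side queries get rotary tables built from video positions, audio-side keys get tables built from audio positions, and the two meet in one cross-attention call.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def toy_rope(seq_len: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical helper: interleaved (cos, sin) tables for a 1D position grid.
    pos = torch.arange(seq_len, dtype=torch.float32)[:, None]
    freqs = torch.pow(10000.0, -torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = pos * freqs                                        # [seq_len, dim // 2]
    cos = angles.cos().repeat_interleave(2, dim=-1)[None]       # [1, seq_len, dim]
    sin = angles.sin().repeat_interleave(2, dim=-1)[None]
    return cos, sin


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Interleaved rotation of channel pairs, mirroring the interleaved layout above.
    x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
    x_rot = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)
    return x * cos + x_rot * sin


batch, heads, dim = 1, 4, 64
video_tokens, audio_tokens = 16, 10
q = torch.randn(batch, video_tokens, heads * dim)   # video queries (a2v direction)
k = torch.randn(batch, audio_tokens, heads * dim)   # audio keys
v = torch.randn(batch, audio_tokens, heads * dim)   # audio values

q = apply_rope(q, *toy_rope(video_tokens, heads * dim))   # video positions for queries
k = apply_rope(k, *toy_rope(audio_tokens, heads * dim))   # audio positions for keys

q, k, v = (t.unflatten(-1, (heads, dim)).transpose(1, 2) for t in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v)             # [1, heads, video_tokens, dim]
print(out.shape)
```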
+ """ + + _default_processor_cls = LTX2AudioVideoAttnProcessor + _available_processors = [LTX2AudioVideoAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + kv_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = True, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + qk_norm: str = "rms_norm_across_heads", + norm_eps: float = 1e-6, + norm_elementwise_affine: bool = True, + rope_type: str = "interleaved", + processor=None, + ): + super().__init__() + if qk_norm != "rms_norm_across_heads": + raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.") + + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = query_dim + self.heads = heads + self.rope_type = rope_type + + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + hidden_states = self.processor( + self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs + ) + return hidden_states + + +class LTX2VideoTransformerBlock(nn.Module): + r""" + Transformer block used in [LTX-2.0](https://huggingface.co/Lightricks/LTX-Video). + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_dim: int, + audio_num_attention_heads: int, + audio_attention_head_dim, + audio_cross_attention_dim: int, + qk_norm: str = "rms_norm_across_heads", + activation_fn: str = "gelu-approximate", + attention_bias: bool = True, + attention_out_bias: bool = True, + eps: float = 1e-6, + elementwise_affine: bool = False, + rope_type: str = "interleaved", + ): + super().__init__() + + # 1. Self-Attention (video and audio) + self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn1 = LTX2Attention( + query_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # 2. Prompt Cross-Attention + self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn2 = LTX2Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn2 = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=audio_cross_attention_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + # Audio-to-Video (a2v) Attention --> Q: Video; K,V: Audio + self.audio_to_video_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_to_video_attn = LTX2Attention( + query_dim=dim, + cross_attention_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video + self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.video_to_audio_attn = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # 4. Feedforward layers + self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn) + + # 5. 
Per-Layer Modulation Parameters + # Self-Attention / Feedforward AdaLayerNorm-Zero mod params + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5) + + # Per-layer a2v, v2a Cross-Attention mod params + self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) + self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim)) + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + temb_audio: torch.Tensor, + temb_ca_scale_shift: torch.Tensor, + temb_ca_audio_scale_shift: torch.Tensor, + temb_ca_gate: torch.Tensor, + temb_ca_audio_gate: torch.Tensor, + video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ca_video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ca_audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, + a2v_cross_attention_mask: Optional[torch.Tensor] = None, + v2a_cross_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size = hidden_states.size(0) + + # 1. Video and Audio Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( + batch_size, temb.size(1), num_ada_params, -1 + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=video_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa + + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) + + num_audio_ada_params = self.audio_scale_shift_table.shape[0] + audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( + batch_size, temb_audio.size(1), num_audio_ada_params, -1 + ) + audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( + audio_ada_values.unbind(dim=2) + ) + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa + + attn_audio_hidden_states = self.audio_attn1( + hidden_states=norm_audio_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=audio_rotary_emb, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa + + # 2. 
Video and Audio Cross-Attention with the text embeddings + norm_hidden_states = self.norm2(hidden_states) + attn_hidden_states = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_rotary_emb=None, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + + norm_audio_hidden_states = self.audio_norm2(audio_hidden_states) + attn_audio_hidden_states = self.audio_attn2( + norm_audio_hidden_states, + encoder_hidden_states=audio_encoder_hidden_states, + query_rotary_emb=None, + attention_mask=audio_encoder_attention_mask, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states + + # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + norm_hidden_states = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) + + # Combine global and per-layer cross attention modulation parameters + # Video + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] + + video_ca_scale_shift_table = ( + video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) + + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + video_ca_gate = ( + video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) + + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) + ).unbind(dim=2) + + video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table + a2v_gate = video_ca_gate[0].squeeze(2) + + # Audio + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + + audio_ca_scale_shift_table = ( + audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype) + + temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + audio_ca_gate = ( + audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) + + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) + ).unbind(dim=2) + + audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table + v2a_gate = audio_ca_gate[0].squeeze(2) + + # Audio-to-Video Cross Attention: Q: Video; K,V: Audio + mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze( + 2 + ) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_a2v_ca_scale.squeeze(2) + ) + audio_a2v_ca_shift.squeeze(2) + + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + query_rotary_emb=ca_video_rotary_emb, + key_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) + + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + # Video-to-Audio Cross Attention: Q: Audio; K,V: Video + mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze( + 2 + ) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_v2a_ca_scale.squeeze(2) + ) + audio_v2a_ca_shift.squeeze(2) + + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + encoder_hidden_states=mod_norm_hidden_states, + query_rotary_emb=ca_audio_rotary_emb, + 
key_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) + + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + + # 4. Feedforward + norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp + + norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) * (1 + audio_scale_mlp) + audio_shift_mlp + audio_ff_output = self.audio_ff(norm_audio_hidden_states) + audio_hidden_states = audio_hidden_states + audio_ff_output * audio_gate_mlp + + return hidden_states, audio_hidden_states + + +class LTX2AudioVideoRotaryPosEmbed(nn.Module): + """ + Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model. + + Args: + causal_offset (`int`, *optional*, defaults to `1`): + Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE + treats the very first frame differently), but could also be 0 (for non-causal modeling). + """ + + def __init__( + self, + dim: int, + patch_size: int = 1, + patch_size_t: int = 1, + base_num_frames: int = 20, + base_height: int = 2048, + base_width: int = 2048, + sampling_rate: int = 16000, + hop_length: int = 160, + scale_factors: Tuple[int, ...] = (8, 32, 32), + theta: float = 10000.0, + causal_offset: int = 1, + modality: str = "video", + double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, + ) -> None: + super().__init__() + + self.dim = dim + self.patch_size = patch_size + self.patch_size_t = patch_size_t + + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.") + self.rope_type = rope_type + + self.base_num_frames = base_num_frames + self.num_attention_heads = num_attention_heads + + # Video-specific + self.base_height = base_height + self.base_width = base_width + + # Audio-specific + self.sampling_rate = sampling_rate + self.hop_length = hop_length + self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0]) + + self.scale_factors = scale_factors + self.theta = theta + self.causal_offset = causal_offset + + self.modality = modality + if self.modality not in ["video", "audio"]: + raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.") + self.double_precision = double_precision + + def prepare_video_coords( + self, + batch_size: int, + num_frames: int, + height: int, + width: int, + device: torch.device, + fps: float = 24.0, + ) -> torch.Tensor: + """ + Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel + space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, num_patches, 2) + where + - axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames) + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + batch_size (`int`): + Batch size of the video latents. + num_frames (`int`): + Number of latent frames in the video latents. + height (`int`): + Latent height of the video latents. + width (`int`): + Latent width of the video latents. + device (`torch.device`): + Device on which to create the video grid. + + Returns: + `torch.Tensor`: + Per-dimension patch boundaries tensor of shape [batch_size, 3, num_patches, 2]. + """ + + # 1. 
Generate grid coordinates for each spatiotemporal dimension (frames, height, width) + # Always compute rope in fp32 + grid_f = torch.arange(start=0, end=num_frames, step=self.patch_size_t, dtype=torch.float32, device=device) + grid_h = torch.arange(start=0, end=height, step=self.patch_size, dtype=torch.float32, device=device) + grid_w = torch.arange(start=0, end=width, step=self.patch_size, dtype=torch.float32, device=device) + # indexing='ij' ensures that the dimensions are kept in order as (frames, height, width) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches + + # 2. Get the patch boundaries with respect to the latent video grid + patch_size = (self.patch_size_t, self.patch_size, self.patch_size) + patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device) + patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) + + # Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension + latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2] + # Reshape to (batch_size, 3, num_patches, 2) + latent_coords = latent_coords.flatten(1, 3) + latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) + + # 3. Calculate the pixel space patch boundaries from the latent boundaries. + scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device) + # Broadcast the VAE scale factors such that they are compatible with latent_coords's shape + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 # This is the (frame, height, width) dim + # Apply per-axis scaling to convert latent coordinates to pixel space coordinates + pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape) + + # As the VAE temporal stride for the first frame is 1 instead of self.vae_scale_factors[0], we need to shift + # and clamp to keep the first-frame timestamps causal and non-negative. + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0) + + # Scale the temporal coordinates by the video FPS + pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps + + return pixel_coords + + def prepare_audio_coords( + self, + batch_size: int, + num_frames: int, + device: torch.device, + shift: int = 0, + ) -> torch.Tensor: + """ + Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent frame. + This will ultimately have shape (batch_size, 3, num_patches, 2) where + - axis 1 (size 1) represents the temporal dimension + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + batch_size (`int`): + Batch size of the audio latents. + num_frames (`int`): + Number of latent frames in the audio latents. + device (`torch.device`): + Device on which to create the audio grid. + shift (`int`, *optional*, defaults to `0`): + Offset on the latent indices. Different shift values correspond to different overlapping windows with + respect to the same underlying latent grid. + + Returns: + `torch.Tensor`: + Per-dimension patch boundaries tensor of shape [batch_size, 1, num_patches, 2]. + """ + + # 1. Generate coordinates in the frame (time) dimension. + # Always compute rope in fp32 + grid_f = torch.arange( + start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device + ) + + # 2. 
Calculate start timstamps in seconds with respect to the original spectrogram grid + audio_scale_factor = self.scale_factors[0] + # Scale back to mel spectrogram space + grid_start_mel = grid_f * audio_scale_factor + # Handle first frame causal offset, ensuring non-negative timestamps + grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0) + # Convert mel bins back into seconds + grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate + + # 3. Calculate start timstamps in seconds with respect to the original spectrogram grid + grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor + grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip(min=0) + grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate + + audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1) # [num_patches, 2] + audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1) # [batch_size, num_patches, 2] + audio_coords = audio_coords.unsqueeze(1) # [batch_size, 1, num_patches, 2] + return audio_coords + + def prepare_coords(self, *args, **kwargs): + if self.modality == "video": + return self.prepare_video_coords(*args, **kwargs) + elif self.modality == "audio": + return self.prepare_audio_coords(*args, **kwargs) + + def forward( + self, coords: torch.Tensor, device: Optional[Union[str, torch.device]] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or coords.device + + # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn) + num_pos_dims = coords.shape[1] + + # 1. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch + # position index + if coords.ndim == 4: + coords_start, coords_end = coords.chunk(2, dim=-1) + coords = (coords_start + coords_end) / 2.0 + coords = coords.squeeze(-1) # [B, num_pos_dims, num_patches] + + # 2. Get coordinates as a fraction of the base data shape + if self.modality == "video": + max_positions = (self.base_num_frames, self.base_height, self.base_width) + elif self.modality == "audio": + max_positions = (self.base_num_frames,) + # [B, num_pos_dims, num_patches] --> [B, num_patches, num_pos_dims] + grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device) + # Number of spatiotemporal dimensions (3 for video, 1 for audio and cross attn) times 2 for cos, sin + num_rope_elems = num_pos_dims * 2 + + # 3. Create a 1D grid of frequencies for RoPE + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + # 4. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape + # (self.dim // num_elems,) + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, num_patches, num_pos_dims, self.dim // num_elems] + freqs = freqs.transpose(-1, -2).flatten(2) # [B, num_patches, self.dim // 2] + + # 5. Get real, interleaved (cos, sin) frequencies, padded to self.dim + # TODO: consider implementing this as a utility and reuse in `connectors.py`. 
+ # src/diffusers/pipelines/ltx2/connectors.py + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) + + cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + + return cos_freqs, sin_freqs + + +class LTX2VideoTransformer3DModel( + ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin +): + r""" + A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). + + Args: + in_channels (`int`, defaults to `128`): + The number of channels in the input. + out_channels (`int`, defaults to `128`): + The number of channels in the output. + patch_size (`int`, defaults to `1`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the tmeporal patches to use in the patch embedding layer. + num_attention_heads (`int`, defaults to `32`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + cross_attention_dim (`int`, defaults to `2048 `): + The number of channels for cross attention heads. + num_layers (`int`, defaults to `28`): + The number of layers of Transformer blocks to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + qk_norm (`str`, defaults to `"rms_norm_across_heads"`): + The normalization layer to use. 
+ """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["LTX2VideoTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + "rope": { + 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + 1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } + + @register_to_config + def __init__( + self, + in_channels: int = 128, # Video Arguments + out_channels: Optional[int] = 128, + patch_size: int = 1, + patch_size_t: int = 1, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + cross_attention_dim: int = 4096, + vae_scale_factors: Tuple[int, int, int] = (8, 32, 32), + pos_embed_max_pos: int = 20, + base_height: int = 2048, + base_width: int = 2048, + audio_in_channels: int = 128, # Audio Arguments + audio_out_channels: Optional[int] = 128, + audio_patch_size: int = 1, + audio_patch_size_t: int = 1, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_cross_attention_dim: int = 2048, + audio_scale_factor: int = 4, + audio_pos_embed_max_pos: int = 20, + audio_sampling_rate: int = 16000, + audio_hop_length: int = 160, + num_layers: int = 48, # Shared arguments + activation_fn: str = "gelu-approximate", + qk_norm: str = "rms_norm_across_heads", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 3840, + attention_bias: bool = True, + attention_out_bias: bool = True, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + causal_offset: int = 1, + timestep_scale_multiplier: int = 1000, + cross_attn_timestep_scale_multiplier: int = 1000, + rope_type: str = "interleaved", + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + audio_out_channels = audio_out_channels or audio_in_channels + inner_dim = num_attention_heads * attention_head_dim + audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + + # 1. Patchification input projections + self.proj_in = nn.Linear(in_channels, inner_dim) + self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim) + + # 2. Prompt embeddings + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=audio_inner_dim + ) + + # 3. Timestep Modulation Params and Embedding + # 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding + # time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters + self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False) + self.audio_time_embed = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=6, use_additional_conditions=False + ) + + # 3.2. Global Cross Attention Modulation Parameters + # Used in the audio-to-video and video-to-audio cross attention layers as a global set of modulation params, + # which are then further modified by per-block modulaton params in each transformer block. 
+ # There are 2 sets of scale/shift parameters for each modality, 1 each for audio-to-video (a2v) and + # video-to-audio (v2a) cross attention + self.av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=4, use_additional_conditions=False + ) + self.av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=4, use_additional_conditions=False + ) + # Gate param for audio-to-video (a2v) cross attn (where the video is the queries (Q) and the audio is the keys + # and values (KV)) + self.av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=1, use_additional_conditions=False + ) + # Gate param for video-to-audio (v2a) cross attn (where the audio is the queries (Q) and the video is the keys + # and values (KV)) + self.av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=1, use_additional_conditions=False + ) + + # 3.3. Output Layer Scale/Shift Modulation parameters + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5) + + # 4. Rotary Positional Embeddings (RoPE) + # Self-Attention + self.rope = LTX2AudioVideoRotaryPosEmbed( + dim=inner_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + scale_factors=vae_scale_factors, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + self.audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_inner_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=audio_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + scale_factors=[audio_scale_factor], + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, + ) + + # Audio-to-Video, Video-to-Audio Cross-Attention + cross_attn_pos_embed_max_pos = max(pos_embed_max_pos, audio_pos_embed_max_pos) + self.cross_attn_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, + ) + + # 5. 
Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + LTX2VideoTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + audio_dim=audio_inner_dim, + audio_num_attention_heads=audio_num_attention_heads, + audio_attention_head_dim=audio_attention_head_dim, + audio_cross_attention_dim=audio_cross_attention_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + # 6. Output layers + self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels) + + self.audio_norm_out = nn.LayerNorm(audio_inner_dim, eps=1e-6, elementwise_affine=False) + self.audio_proj_out = nn.Linear(audio_inner_dim, audio_out_channels) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + audio_timestep: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + fps: float = 24.0, + audio_num_frames: Optional[int] = None, + video_coords: Optional[torch.Tensor] = None, + audio_coords: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> torch.Tensor: + """ + Forward pass for LTX-2.0 audiovisual video transformer. + + Args: + hidden_states (`torch.Tensor`): + Input patchified video latents of shape `(batch_size, num_video_tokens, in_channels)`. + audio_hidden_states (`torch.Tensor`): + Input patchified audio latents of shape `(batch_size, num_audio_tokens, audio_in_channels)`. + encoder_hidden_states (`torch.Tensor`): + Input video text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`. + audio_encoder_hidden_states (`torch.Tensor`): + Input audio text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`. + timestep (`torch.Tensor`): + Input timestep of shape `(batch_size, num_video_tokens)`. These should already be scaled by + `self.config.timestep_scale_multiplier`. + audio_timestep (`torch.Tensor`, *optional*): + Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation + params. This is only used by certain pipelines such as the I2V pipeline. + encoder_attention_mask (`torch.Tensor`, *optional*): + Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`. + audio_encoder_attention_mask (`torch.Tensor`, *optional*): + Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling. + num_frames (`int`, *optional*): + The number of latent video frames. Used if calculating the video coordinates for RoPE. + height (`int`, *optional*): + The latent video height. Used if calculating the video coordinates for RoPE. + width (`int`, *optional*): + The latent video width. Used if calculating the video coordinates for RoPE. + fps: (`float`, *optional*, defaults to `24.0`): + The desired frames per second of the generated video. 
Used if calculating the video coordinates for + RoPE. + audio_num_frames: (`int`, *optional*): + The number of latent audio frames. Used if calculating the audio coordinates for RoPE. + video_coords (`torch.Tensor`, *optional*): + The video coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape + `(batch_size, 3, num_video_tokens, 2)`. If not supplied, this will be calculated inside `forward`. + audio_coords (`torch.Tensor`, *optional*): + The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape + `(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`. + attention_kwargs (`Dict[str, Any]`, *optional*): + Optional dict of keyword args to be passed to the attention processor. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a dict-like structured output of type `AudioVisualModelOutput` or a tuple. + + Returns: + `AudioVisualModelOutput` or `tuple`: + If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a + `tuple` is returned where the first element is the denoised video latent patch sequence and the second + element is the denoised audio latent patch sequence. + """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # Determine timestep for audio. + audio_timestep = audio_timestep if audio_timestep is not None else timestep + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2: + audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0 + audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + + # 1. Prepare RoPE positional embeddings + if video_coords is None: + video_coords = self.rope.prepare_video_coords( + batch_size, num_frames, height, width, hidden_states.device, fps=fps + ) + if audio_coords is None: + audio_coords = self.audio_rope.prepare_audio_coords( + batch_size, audio_num_frames, audio_hidden_states.device + ) + + video_rotary_emb = self.rope(video_coords, device=hidden_states.device) + audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) + + video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) + audio_cross_attn_rotary_emb = self.cross_attn_audio_rope( + audio_coords[:, 0:1, :], device=audio_hidden_states.device + ) + + # 2. Patchify input projections + hidden_states = self.proj_in(hidden_states) + audio_hidden_states = self.audio_proj_in(audio_hidden_states) + + # 3. 
Prepare timestep embeddings and modulation parameters + timestep_cross_attn_gate_scale_factor = ( + self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier + ) + + # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters + # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer + # modulation with scale_shift_table (and similarly for audio) + temb, embedded_timestep = self.time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + + temb_audio, audio_embedded_timestep = self.audio_time_embed( + audio_timestep.flatten(), + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) + audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) + + # 3.2. Prepare global modality cross attention modulation parameters + video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( + timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_scale_shift = video_cross_attn_scale_shift.view( + batch_size, -1, video_cross_attn_scale_shift.shape[-1] + ) + video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) + + audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( + audio_timestep.flatten(), + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( + audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view( + batch_size, -1, audio_cross_attn_scale_shift.shape[-1] + ) + audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) + + # 4. Prepare prompt embeddings + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) + + # 5. 
Run transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, audio_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + audio_hidden_states, + encoder_hidden_states, + audio_encoder_hidden_states, + temb, + temb_audio, + video_cross_attn_scale_shift, + audio_cross_attn_scale_shift, + video_cross_attn_a2v_gate, + audio_cross_attn_v2a_gate, + video_rotary_emb, + audio_rotary_emb, + video_cross_attn_rotary_emb, + audio_cross_attn_rotary_emb, + encoder_attention_mask, + audio_encoder_attention_mask, + ) + else: + hidden_states, audio_hidden_states = block( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, + temb=temb, + temb_audio=temb_audio, + temb_ca_scale_shift=video_cross_attn_scale_shift, + temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, + temb_ca_gate=video_cross_attn_a2v_gate, + temb_ca_audio_gate=audio_cross_attn_v2a_gate, + video_rotary_emb=video_rotary_emb, + audio_rotary_emb=audio_rotary_emb, + ca_video_rotary_emb=video_cross_attn_rotary_emb, + ca_audio_rotary_emb=audio_cross_attn_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, + ) + + # 6. Output layers (including unpatchification) + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + output = self.proj_out(hidden_states) + + audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None] + audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1] + + audio_hidden_states = self.audio_norm_out(audio_hidden_states) + audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift + audio_output = self.audio_proj_out(audio_hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output, audio_output) + return AudioVisualModelOutput(sample=output, audio_sample=audio_output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7107d84350c7..b94319ffcbdc 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -290,6 +290,7 @@ "LTXLatentUpsamplePipeline", "LTXI2VLongMultiPromptPipeline", ] + _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -737,6 +738,7 @@ LTXLatentUpsamplePipeline, LTXPipeline, ) + from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py new file mode 100644 index 000000000000..115e83e827a4 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING 
+ +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["connectors"] = ["LTX2TextConnectors"] + _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"] + _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] + _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] + _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"] + _import_structure["vocoder"] = ["LTX2Vocoder"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .connectors import LTX2TextConnectors + from .latent_upsampler import LTX2LatentUpsamplerModel + from .pipeline_ltx2 import LTX2Pipeline + from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline + from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline + from .vocoder import LTX2Vocoder + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py new file mode 100644 index 000000000000..2608c2783f7e --- /dev/null +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -0,0 +1,325 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.attention import FeedForward +from ...models.modeling_utils import ModelMixin +from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor + + +class LTX2RotaryPosEmbed1d(nn.Module): + """ + 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. + """ + + def __init__( + self, + dim: int, + base_seq_len: int = 4096, + theta: float = 10000.0, + double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, + ): + super().__init__() + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.") + + self.dim = dim + self.base_seq_len = base_seq_len + self.theta = theta + self.double_precision = double_precision + self.rope_type = rope_type + self.num_attention_heads = num_attention_heads + + def forward( + self, + batch_size: int, + pos: int, + device: Union[str, torch.device], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Get 1D position ids + grid_1d = torch.arange(pos, dtype=torch.float32, device=device) + # Get fractional indices relative to self.base_seq_len + grid_1d = grid_1d / self.base_seq_len + grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] + + # 2. 
Calculate 1D RoPE frequencies + num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2 + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape + # (self.dim // 2,). + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2] + + # 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) + + cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + + return cos_freqs, sin_freqs + + +class LTX2TransformerBlock1d(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + activation_fn: str = "gelu-approximate", + eps: float = 1e-6, + rope_type: str = "interleaved", + ): + super().__init__() + + self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + processor=LTX2AudioVideoAttnProcessor(), + rope_type=rope_type, + ) + + self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb) + hidden_states = hidden_states + attn_hidden_states + + norm_hidden_states = self.norm2(hidden_states) + ff_hidden_states = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_hidden_states + + return hidden_states + + +class LTX2ConnectorTransformer1d(nn.Module): + """ + A 1D sequence transformer for modalities such as text. + + In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams. 
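+
+    When `num_learnable_registers` is set, padded token positions are replaced with learnable register embeddings
+    and the additive attention mask is zeroed out, so every position participates in attention. The forward pass
+    returns the RMS-normalized hidden states together with the (possibly updated) attention mask.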
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 128, + num_layers: int = 2, + num_learnable_registers: int | None = 128, + rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + eps: float = 1e-6, + causal_temporal_positioning: bool = False, + rope_type: str = "interleaved", + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + + self.num_learnable_registers = num_learnable_registers + self.learnable_registers = None + if num_learnable_registers is not None: + init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0 + self.learnable_registers = torch.nn.Parameter(init_registers) + + self.rope = LTX2RotaryPosEmbed1d( + self.inner_dim, + base_seq_len=rope_base_seq_len, + theta=rope_theta, + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + + self.transformer_blocks = torch.nn.ModuleList( + [ + LTX2TransformerBlock1d( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attn_mask_binarize_threshold: float = -9000.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # hidden_states shape: [batch_size, seq_len, hidden_dim] + # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len] + batch_size, seq_len, _ = hidden_states.shape + + # 1. Replace padding with learned registers, if using + if self.learnable_registers is not None: + if seq_len % self.num_learnable_registers != 0: + raise ValueError( + f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number" + f" of learnable registers {self.num_learnable_registers}" + ) + + num_register_repeats = seq_len // self.num_learnable_registers + registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim] + + binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int() + if binary_attn_mask.ndim == 4: + binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] + + hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)] + valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded] + pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens] + padded_hidden_states = [ + F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths) + ] + padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D] + + flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1] + hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers + + # Overwrite attention_mask with an all-zeros mask if using registers. + attention_mask = torch.zeros_like(attention_mask) + + # 2. Calculate 1D RoPE positional embeddings + rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device) + + # 3. 
Run 1D transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb) + else: + hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb) + + hidden_states = self.norm_out(hidden_states) + + return hidden_states, attention_mask + + +class LTX2TextConnectors(ModelMixin, ConfigMixin): + """ + Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio + streams. + """ + + @register_to_config + def __init__( + self, + caption_channels: int, + text_proj_in_factor: int, + video_connector_num_attention_heads: int, + video_connector_attention_head_dim: int, + video_connector_num_layers: int, + video_connector_num_learnable_registers: int | None, + audio_connector_num_attention_heads: int, + audio_connector_attention_head_dim: int, + audio_connector_num_layers: int, + audio_connector_num_learnable_registers: int | None, + connector_rope_base_seq_len: int, + rope_theta: float, + rope_double_precision: bool, + causal_temporal_positioning: bool, + rope_type: str = "interleaved", + ): + super().__init__() + self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False) + self.video_connector = LTX2ConnectorTransformer1d( + num_attention_heads=video_connector_num_attention_heads, + attention_head_dim=video_connector_attention_head_dim, + num_layers=video_connector_num_layers, + num_learnable_registers=video_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type, + ) + self.audio_connector = LTX2ConnectorTransformer1d( + num_attention_heads=audio_connector_num_attention_heads, + attention_head_dim=audio_connector_attention_head_dim, + num_layers=audio_connector_num_layers, + num_learnable_registers=audio_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type, + ) + + def forward( + self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False + ): + # Convert to additive attention mask, if necessary + if not additive_mask: + text_dtype = text_encoder_hidden_states.dtype + attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max + + text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states) + + video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask) + + attn_mask = (new_attn_mask < 1e-6).to(torch.int64) + attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) + video_text_embedding = video_text_embedding * attn_mask + new_attn_mask = attn_mask.squeeze(-1) + + audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask) + + return video_text_embedding, audio_text_embedding, new_attn_mask diff --git a/src/diffusers/pipelines/ltx2/export_utils.py b/src/diffusers/pipelines/ltx2/export_utils.py new file mode 100644 index 000000000000..0bc7a59db228 --- /dev/null +++ 
b/src/diffusers/pipelines/ltx2/export_utils.py @@ -0,0 +1,134 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fractions import Fraction +from typing import Optional + +import torch + +from ...utils import is_av_available + + +_CAN_USE_AV = is_av_available() +if _CAN_USE_AV: + import av +else: + raise ImportError( + "PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`" + ) + + +def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: + """ + Prepare the audio stream for writing. + """ + audio_stream = container.add_stream("aac", rate=audio_sample_rate) + audio_stream.codec_context.sample_rate = audio_sample_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) + return audio_stream + + +def _resample_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame +) -> None: + cc = audio_stream.codec_context + + # Use the encoder's format/layout/rate as the *target* + target_format = cc.format or "fltp" # AAC → usually fltp + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or frame_in.sample_rate + + audio_resampler = av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in audio_resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + # flush audio encoder + for packet in audio_stream.encode(): + container.mux(packet) + + +def _write_audio( + container: av.container.Container, + audio_stream: av.audio.AudioStream, + samples: torch.Tensor, + audio_sample_rate: int, +) -> None: + if samples.ndim == 1: + samples = samples[:, None] + + if samples.shape[1] != 2 and samples.shape[0] == 2: + samples = samples.T + + if samples.shape[1] != 2: + raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.") + + # Convert to int16 packed for ingestion; resampler converts to encoder fmt. 
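+    # Float inputs are assumed to be in [-1.0, 1.0]; they are clipped and scaled by 32767 to cover the s16 range.
+    # The [num_samples, 2] tensor is then flattened row-major into a single packed row (L, R, L, R, ...), which is
+    # the layout `av.AudioFrame.from_ndarray(..., format="s16", layout="stereo")` expects for interleaved stereo.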
+ if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + _resample_audio(container, audio_stream, frame_in) + + +def encode_video( + video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str +) -> None: + video_np = video.cpu().numpy() + + _, height, width, _ = video_np.shape + + container = av.open(output_path, mode="w") + stream = container.add_stream("libx264", rate=int(fps)) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + if audio is not None: + if audio_sample_rate is None: + raise ValueError("audio_sample_rate is required when audio is provided") + + audio_stream = _prepare_audio_stream(container, audio_sample_rate) + + for frame_array in video_np: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + + # Flush encoder + for packet in stream.encode(): + container.mux(packet) + + if audio is not None: + _write_audio(container, audio_stream, audio, audio_sample_rate) + + container.close() diff --git a/src/diffusers/pipelines/ltx2/latent_upsampler.py b/src/diffusers/pipelines/ltx2/latent_upsampler.py new file mode 100644 index 000000000000..69a9b1d9193f --- /dev/null +++ b/src/diffusers/pipelines/ltx2/latent_upsampler.py @@ -0,0 +1,285 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import Optional + +import torch +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +RATIONAL_RESAMPLER_SCALE_MAPPING = { + 0.75: (3, 4), + 1.5: (3, 2), + 2.0: (2, 1), + 4.0: (4, 1), +} + + +# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock +class ResBlock(torch.nn.Module): + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.norm2(hidden_states) + hidden_states = self.activation(hidden_states + residual) + return hidden_states + + +# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND +class PixelShuffleND(torch.nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + + self.dims = dims + self.upscale_factors = upscale_factors + + if dims not in [1, 2, 3]: + raise ValueError("dims must be 1, 2, or 3") + + def forward(self, x): + if self.dims == 3: + # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:3])) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(6, 7) + .flatten(4, 5) + .flatten(2, 3) + ) + elif self.dims == 2: + # spatial: b (c p1 p2) h w -> b c (h p1) (w p2) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3) + ) + elif self.dims == 1: + # temporal: b (c p1) f h w -> b c (f p1) h w + return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3) + + +class BlurDownsample(torch.nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W. + Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None: + super().__init__() + + if dims not in (2, 3): + raise ValueError(f"`dims` must be either 2 or 3 but is {dims}") + if kernel_size < 3 or kernel_size % 2 != 1: + raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}") + + self.dims = dims + self.stride = stride + self.kernel_size = kernel_size + + # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from + # the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and + # provides a smooth approximation of a Gaussian filter (often called a "binomial filter"). + # The 2D kernel is constructed as the outer product and normalized. 
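+        # For the default kernel_size=5, k = [1, 4, 6, 4, 1] and the entries of k2d = k^T @ k sum to 16 * 16 = 256,
+        # so the center tap is 36 / 256 and each corner tap is 1 / 256. The kernel is stored as a buffer and
+        # expanded into a depthwise filter in `forward`.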
+ k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + if self.dims == 2: + c = x.shape[1] + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + else: + # dims == 3: apply per-frame on H,W + b, c, f, _, _ = x.shape + x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] + + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + + h2, w2 = x.shape[-2:] + x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W] + return x + + +class SpatialRationalResampler(torch.nn.Module): + """ + Scales by the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample + by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the + input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the + (integer) denominator. + """ + + def __init__(self, mid_channels: int = 1024, scale: float = 2.0): + super().__init__() + self.scale = float(scale) + num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None) + if num_denom is None: + raise ValueError( + f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}" + ) + self.num, self.den = num_denom + + self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Expected x shape: [B * F, C, H, W] + # b, _, f, h, w = x.shape + # x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + # x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W] + return x + + +class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. 
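+
+    When `rational_spatial_scale` is set and only spatial upsampling is requested, the upsampler is a
+    `SpatialRationalResampler`, which upsamples by the scale's integer numerator via pixel shuffle and then applies
+    an anti-aliased blur + stride downsample by its integer denominator.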
+ + Args: + in_channels (`int`, defaults to `128`): + Number of channels in the input latent + mid_channels (`int`, defaults to `512`): + Number of channels in the middle layers + num_blocks_per_stage (`int`, defaults to `4`): + Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`, defaults to `3`): + Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`, defaults to `True`): + Whether to spatially upsample the latent + temporal_upsample (`bool`, defaults to `False`): + Whether to temporally upsample the latent + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 1024, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + rational_spatial_scale: Optional[float] = 2.0, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_spatial_scale is not None: + self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale) + else: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.dims == 2: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.upsampler(hidden_states) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + else: + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + if self.temporal_upsample: + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states[:, :, 1:, :, :] + else: + 
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + + return hidden_states diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py new file mode 100644 index 000000000000..99d6b71ec3d7 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -0,0 +1,1141 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .connectors import LTX2TextConnectors +from .pipeline_output import LTX2PipelineOutput +from .vocoder import LTX2Vocoder + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2Pipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + + >>> pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> frame_rate = 24.0 + >>> video, audio = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="np", + ... return_dict=False, + ... 
) + >>> video = (video * 255).round().astype("uint8") + >>> video = torch.from_numpy(video) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://huggingface.co/papers/2305.08891).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
+    r"""
+    Pipeline for text-to-video generation with synchronized audio.
+
+    Reference: https://github.com/Lightricks/LTX-2
+
+    Args:
+        transformer ([`LTX2VideoTransformer3DModel`]):
+            Conditional Transformer architecture to denoise the encoded video and audio latents.
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded latents.
+        vae ([`AutoencoderKLLTX2Video`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        audio_vae ([`AutoencoderKLLTX2Audio`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode audio spectrograms to and from latent
+            representations.
+        text_encoder ([`~transformers.Gemma3ForConditionalGeneration`]):
+            Gemma 3 text encoder used to compute per-layer hidden states for the prompt.
+        tokenizer ([`~transformers.GemmaTokenizer`] or [`~transformers.GemmaTokenizerFast`]):
+            Tokenizer associated with the Gemma 3 text encoder.
+        connectors ([`LTX2TextConnectors`]):
+            Text connector stack used to adapt text encoder hidden states for the video and audio branches.
+        vocoder ([`LTX2Vocoder`]):
+            Vocoder used to convert the generated audio spectrograms into audio waveforms.
+ """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + # TODO: check whether the MEL compression ratio logic here is corrct + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + @staticmethod + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: Union[str, torch.device], + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. 
+ """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. 
+ """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
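+        # For example, with the defaults used elsewhere in this pipeline (patch_size = patch_size_t = 1,
+        # 128 latent channels, a 512x768x121 video, and 32x spatial / 8x temporal VAE compression), the latent
+        # video is [B, 128, 16, 16, 24] and the packed sequence is [B, 16 * 16 * 24, 128] = [B, 6144, 128].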
+        batch_size = latents.size(0)
+        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+        return latents
+
+    @staticmethod
+    def _denormalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        # Denormalize latents across the channel dimension [B, C, F, H, W]
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = latents * latents_std / scaling_factor + latents_mean
+        return latents
+
+    @staticmethod
+    def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
+        latents_mean = latents_mean.to(latents.device, latents.dtype)
+        latents_std = latents_std.to(latents.device, latents.dtype)
+        return (latents * latents_std) + latents_mean
+
+    @staticmethod
+    def _pack_audio_latents(
+        latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
+    ) -> torch.Tensor:
+        # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins
+        if patch_size is not None and patch_size_t is not None:
+            # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (an ndim=3 tensor).
+            # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size.
+            batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
+            post_patch_latent_length = latent_length // patch_size_t
+            post_patch_mel_bins = latent_mel_bins // patch_size
+            latents = latents.reshape(
+                batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
+            )
+            latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+        else:
+            # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel)
+            # patch_size of M (all mel bins constitute a single patch) and a patch_size_t of 1.
+            latents = latents.transpose(1, 2).flatten(2, 3)  # [B, C, L, M] --> [B, L, C * M]
+        return latents
+
+    @staticmethod
+    def _unpack_audio_latents(
+        latents: torch.Tensor,
+        latent_length: int,
+        num_mel_bins: int,
+        patch_size: Optional[int] = None,
+        patch_size_t: Optional[int] = None,
+    ) -> torch.Tensor:
+        # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M],
+        # where L is the latent audio length and M is the number of mel bins.
+        if patch_size is not None and patch_size_t is not None:
+            batch_size = latents.size(0)
+            latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
+            latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
+        else:
+            # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1.
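+            # With the defaults used in `prepare_audio_latents` (8 latent channels, 64 mel bins, and a 4x mel
+            # compression ratio), M = 64 // 4 = 16 and D = C * M = 8 * 16 = 128.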
+ latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + return latents + + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + num_mel_bins: int = 64, + num_frames: int = 121, + frame_rate: float = 25.0, + sampling_rate: int = 16000, + hop_length: int = 160, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + duration_s = num_frames / frame_rate + latents_per_second = ( + float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + ) + latent_length = round(duration_s * latents_per_second) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_length + + # TODO: confirm whether this logic is correct + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents, latent_length + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 40, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + audio_latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate + frame_rate (`float`, *optional*, defaults to `24.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTX2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. 
`callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. 
Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents, audio_num_frames = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + num_mel_bins=num_mel_bins, + num_frames=num_frames, # Video frames, audio frames will be calculated from this + frame_rate=frame_rate, + sampling_rate=self.audio_sampling_rate, + hop_length=self.audio_hop_length, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + + # 7. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if self.guidance_rescale > 0: + # Based on 3.4. 
in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py new file mode 100644 index 000000000000..b1711e283191 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -0,0 +1,1238 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
+from ...models.transformers import LTX2VideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .connectors import LTX2TextConnectors
+from .pipeline_output import LTX2PipelineOutput
+from .vocoder import LTX2Vocoder
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import LTX2ImageToVideoPipeline
+        >>> from diffusers.pipelines.ltx2.export_utils import encode_video
+        >>> from diffusers.utils import load_image
+
+        >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
+        >>> pipe.enable_model_cpu_offload()
+
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
+        ... )
+        >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
+        >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+        >>> frame_rate = 24.0
+        >>> video, audio = pipe(
+        ...     image=image,
+        ...     prompt=prompt,
+        ...     negative_prompt=negative_prompt,
+        ...     width=768,
+        ...     height=512,
+        ...     num_frames=121,
+        ...     frame_rate=frame_rate,
+        ...     num_inference_steps=40,
+        ...     guidance_scale=4.0,
+        ...     output_type="np",
+        ...     return_dict=False,
+        ... )
+        >>> video = (video * 255).round().astype("uint8")
+        >>> video = torch.from_numpy(video)
+
+        >>> encode_video(
+        ...     video[0],
+        ...     fps=frame_rate,
+        ...     audio=audio[0].float().cpu(),
+        ...     audio_sample_rate=pipe.vocoder.config.output_sampling_rate,  # should be 24000
+        ...     output_path="video.mp4",
) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Pipeline for image-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + TODO + """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + # TODO: check whether the MEL compression ratio logic here is corrct + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + 
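+        # Note: with the fallback values used when the audio components are not loaded (sample_rate=16000,
+        # mel_hop_length=160, audio temporal compression ratio 4), `prepare_audio_latents` produces
+        # 16000 / 160 / 4 = 25 audio latent frames per second of video.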
self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear") + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: Union[str, torch.device], + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. 
+        """
+        batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
+        original_dtype = text_hidden_states.dtype
+
+        # Create padding mask
+        token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
+        if padding_side == "right":
+            # For right padding, valid tokens are from 0 to sequence_length-1
+            mask = token_indices < sequence_lengths[:, None]  # [batch_size, seq_len]
+        elif padding_side == "left":
+            # For left padding, valid tokens are from (T - sequence_length) to T-1
+            start_indices = seq_len - sequence_lengths[:, None]  # [batch_size, 1]
+            mask = token_indices >= start_indices  # [B, T]
+        else:
+            raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
+        mask = mask[:, :, None, None]  # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
+
+        # Compute masked mean over non-padding positions, of shape (batch_size, 1, 1, num_layers)
+        masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
+        num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
+        masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
+
+        # Compute min/max over non-padding positions, of shape (batch_size, 1, 1, num_layers)
+        x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
+        x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
+
+        # Normalization
+        normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
+        normalized_hidden_states = normalized_hidden_states * scale_factor
+
+        # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
+        normalized_hidden_states = normalized_hidden_states.flatten(2)
+        mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
+        normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
+        normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
+        return normalized_hidden_states
+
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
+    def _get_gemma_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        num_videos_per_prompt: int = 1,
+        max_sequence_length: int = 1024,
+        scale_factor: int = 8,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`str` or `torch.device`):
+                torch device to place the resulting embeddings on
+            dtype: (`torch.dtype`):
+                torch dtype to cast the prompt embeds to
+            max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt.
+ """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
+ negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
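+        # Illustrative example (hypothetical shapes): with patch_size=patch_size_t=1, a packed tensor of shape
+        # [1, 4 * 16 * 24, 128] unpacks back to [1, 128, 4, 16, 24], i.e. [B, C, F, H, W].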
+        batch_size = latents.size(0)
+        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+        return latents
+
+    @staticmethod
+    def _normalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        # Normalize latents across the channel dimension [B, C, F, H, W]
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = (latents - latents_mean) * scaling_factor / latents_std
+        return latents
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
+    def _denormalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        # Denormalize latents across the channel dimension [B, C, F, H, W]
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = latents * latents_std / scaling_factor + latents_mean
+        return latents
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents
+    def _pack_audio_latents(
+        latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
+    ) -> torch.Tensor:
+        # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins
+        if patch_size is not None and patch_size_t is not None:
+            # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (an ndim=3 tensor).
+            # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size.
+            batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
+            post_patch_latent_length = latent_length // patch_size_t
+            post_patch_mel_bins = latent_mel_bins // patch_size
+            latents = latents.reshape(
+                batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
+            )
+            latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+        else:
+            # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel)
+            # patch_size of M (all mel bins constitute a single patch) and a patch_size_t of 1.
+            latents = latents.transpose(1, 2).flatten(2, 3)  # [B, C, L, M] --> [B, L, C * M]
+        return latents
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents
+    def _unpack_audio_latents(
+        latents: torch.Tensor,
+        latent_length: int,
+        num_mel_bins: int,
+        patch_size: Optional[int] = None,
+        patch_size_t: Optional[int] = None,
+    ) -> torch.Tensor:
+        # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M],
+        # where L is the latent audio length and M is the number of mel bins.
+        if patch_size is not None and patch_size_t is not None:
+            batch_size = latents.size(0)
+            latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
+            latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
+        else:
+            # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1.
+ latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + def prepare_latents( + self, + image: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + mask_shape = (batch_size, 1, num_frames, height, width) + + if latents is not None: + conditioning_mask = latents.new_zeros(mask_shape) + conditioning_mask[:, :, 0] = 1.0 + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." + ) + return latents.to(device=device, dtype=dtype), conditioning_mask + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i], "argmax") + for i in range(batch_size) + ] + else: + init_latents = [ + retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator, "argmax") for img in image + ] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) + + # First condition is image latents and those should be kept clean. + conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) + conditioning_mask[:, :, 0] = 1.0 + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # Interpolation. 
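+        # Frame 0 keeps the clean encoded image latents (conditioning_mask == 1); all remaining latent frames
+        # start from pure noise and are denoised during sampling.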
+ latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) + + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + return latents, conditioning_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + num_mel_bins: int = 64, + num_frames: int = 121, + frame_rate: float = 25.0, + sampling_rate: int = 16000, + hop_length: int = 160, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + duration_s = num_frames / frame_rate + latents_per_second = ( + float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + ) + latent_length = round(duration_s * latents_per_second) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_length + + # TODO: confirm whether this logic is correct + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents, latent_length + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 40, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + audio_latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: 
Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate + frame_rate (`float`, *optional*, defaults to `24.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTX2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. Prepare latent variables + if latents is None: + image = self.video_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=prompt_embeds.dtype) + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents, audio_num_frames = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + num_mel_bins=num_mel_bins, + num_frames=num_frames, # Video frames, audio frames will be calculated from this + frame_rate=frame_rate, + sampling_rate=self.audio_sampling_rate, + hop_length=self.audio_hop_length, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. 
Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep = t.expand(latent_model_input.shape[0]) + video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) 
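+                    # Classifier-free guidance is applied separately to the video prediction (above) and the
+                    # audio prediction (below), using the same guidance weight for both modalities.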
+ + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if self.guidance_rescale > 0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred_video = self._unpack_latents( + noise_pred_video, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + noise_pred_video = noise_pred_video[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + 
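+            # Decode the video latents with the video VAE; the audio latents are decoded into a mel spectrogram
+            # by the audio VAE and then rendered to a waveform by the vocoder.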
latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py new file mode 100644 index 000000000000..a44c40b0430f --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -0,0 +1,442 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch + +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLLTX2Video +from ...utils import get_logger, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..ltx.pipeline_output import LTXPipelineOutput +from ..pipeline_utils import DiffusionPipeline +from .latent_upsampler import LTX2LatentUpsamplerModel + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel + >>> from diffusers.utils import load_image + + >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> image = load_image( + ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + ... ) + >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> frame_rate = 24.0 + >>> video, audio = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="pil", + ... return_dict=False, + ... ) + + >>> latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + ... "Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=torch.bfloat16 + ... 
) + >>> upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) + >>> upsample_pipe.vae.enable_tiling() + >>> upsample_pipe.to(device="cuda", dtype=torch.bfloat16) + + >>> video = upsample_pipe( + ... video=video, + ... width=768, + ... height=512, + ... output_type="np", + ... return_dict=False, + ... )[0] + >>> video = (video * 255).round().astype("uint8") + >>> video = torch.from_numpy(video) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LTX2LatentUpsamplePipeline(DiffusionPipeline): + model_cpu_offload_seq = "vae->latent_upsampler" + + def __init__( + self, + vae: AutoencoderKLLTX2Video, + latent_upsampler: LTX2LatentUpsamplerModel, + ) -> None: + super().__init__() + + self.register_modules(vae=vae, latent_upsampler=latent_upsampler) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + def prepare_latents( + self, + video: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_frames: int = 121, + height: int = 512, + width: int = 768, + spatial_patch_size: int = 1, + temporal_patch_size: int = 1, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 3: + # Convert token seq [B, S, D] to latent video [B, C, F, H, W] + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + latents = self._unpack_latents( + latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size + ) + return latents.to(device=device, dtype=dtype) + + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + # NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here + # init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + return init_latents + + def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent + tensor. + + Args: + latent (`torch.Tensor`): + Input latents to normalize + reference_latents (`torch.Tensor`): + The reference latents providing style statistics. + factor (`float`): + Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result + + def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor: + """ + Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually + smooth way using a sigmoid-based compression. + + This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially + when controlling dynamic behavior with a `compression` factor. + + Args: + latents : torch.Tensor + Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range. + compression : float + Compression strength in the range [0, 1]. + - 0.0: No tone-mapping (identity transform) + - 1.0: Full compression effect + + Returns: + torch.Tensor + The tone-mapped latent tensor of the same shape as input. 
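+
+        Internally this computes `latents * (1 - 0.8 * s * sigmoid(4 * s * (abs(latents) - 1)))` with
+        `s = 0.75 * compression`, so values close to zero pass through nearly unchanged while large-magnitude
+        values are progressively compressed.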
+ """ + # Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot + scale_factor = compression * 0.75 + abs_latents = torch.abs(latents) + + # Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0 + # When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect + sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0)) + scales = 1.0 - 0.8 * scale_factor * sigmoid_term + + filtered = latents * scales + return filtered + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
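+        # For example, with patch_size = patch_size_t = 1, S equals F * H * W and D equals C, so this is a pure
+        # [B, F * H * W, C] -> [B, C, F, H, W] rearrangement of the token sequence back into a latent video.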
+ batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + def check_inputs(self, video, height, width, latents, tone_map_compression_ratio): + if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` can be provided.") + if video is None and latents is None: + raise ValueError("One of `video` or `latents` has to be provided.") + + if not (0 <= tone_map_compression_ratio <= 1): + raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: Optional[List[PipelineImageInput]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + spatial_patch_size: int = 1, + temporal_patch_size: int = 1, + latents: Optional[torch.Tensor] = None, + latents_normalized: bool = False, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + adain_factor: float = 0.0, + tone_map_compression_ratio: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + video (`List[PipelineImageInput]`, *optional*) + The video to be upsampled (such as a LTX 2.0 first stage output). If not supplied, `latents` should be + supplied. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the input video (not the generated video, which will have a larger resolution). + width (`int`, *optional*, defaults to `768`): + The width in pixels of the input video (not the generated video, which will have a larger resolution). + num_frames (`int`, *optional*, defaults to `121`): + The number of frames in the input video. + spatial_patch_size (`int`, *optional*, defaults to `1`): + The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is necessary. + temporal_patch_size (`int`, *optional*, defaults to `1`): + The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is + necessary. + latents (`torch.Tensor`, *optional*): + Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a + patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size, + latent_channels, latent_frames, latent_height, latent_width)`. + latents_normalized (`bool`, *optional*, defaults to `False`) + If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If + `True`, the `latents` will be denormalized before being supplied to the latent upsampler. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + adain_factor (`float`, *optional*, defaults to `0.0`): + Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents. 
+ Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed. + tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`): + The compression strength for tone mapping, which will reduce the dynamic range of the latent values. + This is useful for regularizing high-variance latents or for conditioning outputs during generation. + Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to + the full compression effect. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is the upsampled video. + """ + + self.check_inputs( + video=video, + height=height, + width=width, + latents=latents, + tone_map_compression_ratio=tone_map_compression_ratio, + ) + + if video is not None: + # Batched video input is not yet tested/supported. TODO: take a look later + batch_size = 1 + else: + batch_size = latents.shape[0] + device = self._execution_device + + if video is not None: + num_frames = len(video) + if num_frames % self.vae_temporal_compression_ratio != 1: + num_frames = ( + num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1 + ) + video = video[:num_frames] + logger.warning( + f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames." 
+ ) + video = self.video_processor.preprocess_video(video, height=height, width=width) + video = video.to(device=device, dtype=torch.float32) + + latents_supplied = latents is not None + latents = self.prepare_latents( + video=video, + batch_size=batch_size, + num_frames=num_frames, + height=height, + width=width, + spatial_patch_size=spatial_patch_size, + temporal_patch_size=temporal_patch_size, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + + if latents_supplied and latents_normalized: + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.latent_upsampler.dtype) + latents_upsampled = self.latent_upsampler(latents) + + if adain_factor > 0.0: + latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor) + else: + latents = latents_upsampled + + if tone_map_compression_ratio > 0.0: + latents = self.tone_map_latents(latents, tone_map_compression_ratio) + + if output_type == "latent": + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + video = latents + else: + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/ltx2/pipeline_output.py b/src/diffusers/pipelines/ltx2/pipeline_output.py new file mode 100644 index 000000000000..eacd571125b0 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_output.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class LTX2PipelineOutput(BaseOutput): + r""" + Output class for LTX pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ audio (`torch.Tensor`, `np.ndarray`): + TODO + """ + + frames: torch.Tensor + audio: torch.Tensor diff --git a/src/diffusers/pipelines/ltx2/vocoder.py b/src/diffusers/pipelines/ltx2/vocoder.py new file mode 100644 index 000000000000..217c68103e39 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/vocoder.py @@ -0,0 +1,159 @@ +import math +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class ResBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int = 3, + stride: int = 1, + dilations: Tuple[int, ...] = (1, 3, 5), + leaky_relu_negative_slope: float = 0.1, + padding_mode: str = "same", + ): + super().__init__() + self.dilations = dilations + self.negative_slope = leaky_relu_negative_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode) + for dilation in dilations + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode) + for _ in range(len(dilations)) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, negative_slope=self.negative_slope) + xt = conv1(xt) + xt = F.leaky_relu(xt, negative_slope=self.negative_slope) + xt = conv2(xt) + x = x + xt + return x + + +class LTX2Vocoder(ModelMixin, ConfigMixin): + r""" + LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + hidden_channels: int = 1024, + out_channels: int = 2, + upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4], + upsample_factors: List[int] = [6, 5, 2, 2, 2], + resnet_kernel_sizes: List[int] = [3, 7, 11], + resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + leaky_relu_negative_slope: float = 0.1, + output_sampling_rate: int = 24000, + ): + super().__init__() + self.num_upsample_layers = len(upsample_kernel_sizes) + self.resnets_per_upsample = len(resnet_kernel_sizes) + self.out_channels = out_channels + self.total_upsample_factor = math.prod(upsample_factors) + self.negative_slope = leaky_relu_negative_slope + + if self.num_upsample_layers != len(upsample_factors): + raise ValueError( + f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length" + f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively." + ) + + if self.resnets_per_upsample != len(resnet_dilations): + raise ValueError( + f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length" + f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." 
+ ) + + self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3) + + self.upsamplers = nn.ModuleList() + self.resnets = nn.ModuleList() + input_channels = hidden_channels + for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + output_channels = input_channels // 2 + self.upsamplers.append( + nn.ConvTranspose1d( + input_channels, # hidden_channels // (2 ** i) + output_channels, # hidden_channels // (2 ** (i + 1)) + kernel_size, + stride=stride, + padding=(kernel_size - stride) // 2, + ) + ) + + for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): + self.resnets.append( + ResBlock( + output_channels, + kernel_size, + dilations=dilations, + leaky_relu_negative_slope=leaky_relu_negative_slope, + ) + ) + input_channels = output_channels + + self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) + + def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: + r""" + Forward pass of the vocoder. + + Args: + hidden_states (`torch.Tensor`): + Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last` + is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is + `True`. + time_last (`bool`, *optional*, defaults to `False`): + Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension. + + Returns: + `torch.Tensor`: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + + # Ensure that the time/frame dimension is last + if not time_last: + hidden_states = hidden_states.transpose(2, 3) + # Combine channels and frequency (mel bins) dimensions + hidden_states = hidden_states.flatten(1, 2) + + hidden_states = self.conv_in(hidden_states) + + for i in range(self.num_upsample_layers): + hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) + hidden_states = self.upsamplers[i](hidden_states) + + # Run all resnets in parallel on hidden_states + start = i * self.resnets_per_upsample + end = (i + 1) * self.resnets_per_upsample + resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0) + + hidden_states = torch.mean(resnet_outputs, dim=0) + + # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of + # 0.01 (whereas the others usually use a slope of 0.1). 
Not sure if this is intended + hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01) + hidden_states = self.conv_out(hidden_states) + hidden_states = torch.tanh(hidden_states) + + return hidden_states diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 440e4539e720..e726bbb46913 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -66,6 +66,7 @@ is_accelerate_version, is_aiter_available, is_aiter_version, + is_av_available, is_better_profanity_available, is_bitsandbytes_available, is_bitsandbytes_version, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 62426ffbf65c..bb94c94da360 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -502,6 +502,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLLTX2Audio(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoencoderKLLTX2Video(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLLTXVideo(metaclass=DummyObject): _backends = ["torch"] @@ -1147,6 +1177,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LTX2VideoTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LTXVideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 63cec365799b..a7f0c5b85dd8 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1877,6 +1877,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTX2ImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LTX2LatentUpsamplePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LTX2Pipeline(metaclass=DummyObject): + _backends = 
["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXConditionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 57b0a337922a..425c360a3110 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -230,6 +230,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _aiter_available, _aiter_version = _is_package_available("aiter") _kornia_available, _kornia_version = _is_package_available("kornia") _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True) +_av_available, _av_version = _is_package_available("av") def is_torch_available(): @@ -420,6 +421,10 @@ def is_kornia_available(): return _kornia_available +def is_av_available(): + return _av_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py b/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py new file mode 100644 index 000000000000..ce93dfb42afe --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from diffusers import AutoencoderKLLTX2Audio + +from ...testing_utils import ( + floats_tensor, + torch_device, +) +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin + + +class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): + model_class = AutoencoderKLLTX2Audio + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_ltx_video_config(self): + return { + "in_channels": 2, # stereo, + "output_channels": 2, + "latent_channels": 4, + "base_channels": 16, + "ch_mult": (1, 2, 4), + "resolution": 16, + "attn_resolutions": None, + "num_res_blocks": 2, + "norm_type": "pixel", + "causality_axis": "height", + "mid_block_add_attention": False, + "sample_rate": 16000, + "mel_hop_length": 160, + "mel_bins": 16, + "is_causal": True, + "double_z": True, + } + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 2 + num_frames = 8 + num_mel_bins = 16 + + spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device) + + input_dict = {"sample": spectrogram} + return input_dict + + @property + def input_shape(self): + return (2, 5, 16) + + @property + def output_shape(self): + return (2, 5, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_ltx_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + # Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE + def test_output(self): + super().test_output(expected_output_shape=(2, 2, 5, 16)) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm.") + def test_forward_with_norm_groups(self): + pass diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py new file mode 100644 index 000000000000..146241361a82 --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py @@ -0,0 +1,103 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from diffusers import AutoencoderKLLTX2Video + +from ...testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): + model_class = AutoencoderKLLTX2Video + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_ltx_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 8, + "block_out_channels": (8, 8, 8, 8), + "decoder_block_out_channels": (16, 32, 64), + "layers_per_block": (1, 1, 1, 1, 1), + "decoder_layers_per_block": (1, 1, 1, 1), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 1, + "patch_size_t": 1, + "encoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + # Full model uses `reflect` but this does not have deterministic backward implementation, so use `zeros` + "decoder_spatial_padding_mode": "zeros", + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + input_dict = {"sample": image} + return input_dict + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_ltx_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "LTX2VideoEncoder3d", + "LTX2VideoDecoder3d", + "LTX2VideoDownBlock3D", + "LTX2VideoMidBlock3d", + "LTX2VideoUpBlock3d", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.") + def test_forward_with_norm_groups(self): + pass diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py new file mode 100644 index 000000000000..af9ef0623891 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -0,0 +1,222 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import LTX2VideoTransformer3DModel + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + +class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = LTX2VideoTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + # Common + batch_size = 2 + + # Video + num_frames = 2 + num_channels = 4 + height = 16 + width = 16 + + # Audio + audio_num_frames = 9 + audio_num_channels = 2 + num_mel_bins = 2 + + # Text + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) + audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to( + torch_device + ) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) + timestep = torch.rand((batch_size,)).to(torch_device) * 1000 + + return { + "hidden_states": hidden_states, + "audio_hidden_states": audio_hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "audio_encoder_hidden_states": audio_encoder_hidden_states, + "timestep": timestep, + "encoder_attention_mask": encoder_attention_mask, + "num_frames": num_frames, + "height": height, + "width": width, + "audio_num_frames": audio_num_frames, + "fps": 25.0, + } + + @property + def input_shape(self): + return (512, 4) + + @property + def output_shape(self): + return (512, 4) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "audio_in_channels": 4, + "audio_out_channels": 4, + "audio_num_attention_heads": 2, + "audio_attention_head_dim": 4, + "audio_cross_attention_dim": 8, + "num_layers": 2, + "qk_norm": "rms_norm_across_heads", + "caption_channels": 16, + "rope_double_precision": False, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"LTX2VideoTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + # def test_ltx2_consistency(self, seed=0, dtype=torch.float32): + # torch.manual_seed(seed) + # init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + # # Calculate dummy inputs in a custom manner to ensure compatibility with original code + # batch_size = 2 + # num_frames = 9 + # latent_frames = 2 + # text_embedding_dim = 16 + # text_seq_len = 16 + # fps = 25.0 + # sampling_rate = 16000.0 + # hop_length = 160.0 + + # sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000 + # timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device) + + # num_channels = 4 + # latent_height = 4 + # latent_width = 4 + # hidden_states = torch.randn( + # (batch_size, num_channels, latent_frames, latent_height, latent_width), + # generator=torch.manual_seed(seed), + # dtype=dtype, + # device="cpu", + # ) + # # Patchify video latents (with patch_size (1, 1, 1)) + # 
hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1) + # hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + # encoder_hidden_states = torch.randn( + # (batch_size, text_seq_len, text_embedding_dim), + # generator=torch.manual_seed(seed), + # dtype=dtype, + # device="cpu", + # ) + + # audio_num_channels = 2 + # num_mel_bins = 2 + # latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps)) + # audio_hidden_states = torch.randn( + # (batch_size, audio_num_channels, latent_length, num_mel_bins), + # generator=torch.manual_seed(seed), + # dtype=dtype, + # device="cpu", + # ) + # # Patchify audio latents + # audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3) + # audio_encoder_hidden_states = torch.randn( + # (batch_size, text_seq_len, text_embedding_dim), + # generator=torch.manual_seed(seed), + # dtype=dtype, + # device="cpu", + # ) + + # inputs_dict = { + # "hidden_states": hidden_states.to(device=torch_device), + # "audio_hidden_states": audio_hidden_states.to(device=torch_device), + # "encoder_hidden_states": encoder_hidden_states.to(device=torch_device), + # "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device), + # "timestep": timestep, + # "num_frames": latent_frames, + # "height": latent_height, + # "width": latent_width, + # "audio_num_frames": num_frames, + # "fps": 25.0, + # } + + # model = self.model_class.from_pretrained( + # "diffusers-internal-dev/dummy-ltx2", + # subfolder="transformer", + # device_map="cpu", + # ) + # # torch.manual_seed(seed) + # # model = self.model_class(**init_dict) + # model.to(torch_device) + # model.eval() + + # with attention_backend("native"): + # with torch.no_grad(): + # output = model(**inputs_dict) + + # video_output, audio_output = output.to_tuple() + + # self.assertIsNotNone(video_output) + # self.assertIsNotNone(audio_output) + + # # input & output have to have the same shape + # video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels) + # self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match") + # audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins) + # self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match") + + # # Check against expected slice + # # fmt: off + # video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676]) + # audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692]) + # # fmt: on + + # video_output_flat = video_output.cpu().flatten().float() + # video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]]) + # self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4)) + + # audio_output_flat = audio_output.cpu().flatten().float() + # audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]]) + # self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4)) + + +class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = LTX2VideoTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return 
LTX2TransformerTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/ltx2/__init__.py b/tests/pipelines/ltx2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/ltx2/test_ltx2.py b/tests/pipelines/ltx2/test_ltx2.py new file mode 100644 index 000000000000..6ffc23725022 --- /dev/null +++ b/tests/pipelines/ltx2/test_ltx2.py @@ -0,0 +1,239 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2Pipeline, + LTX2VideoTransformer3DModel, +) +from diffusers.pipelines.ltx2 import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder + +from ...testing_utils import enable_full_determinism +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTX2Pipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "audio_latents", + "output_type", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_attention_slicing = False + test_xformers_attention = False + supports_dduf = False + + base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3" + + def get_dummy_components(self): + tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id) + + torch.manual_seed(0) + transformer = LTX2VideoTransformer3DModel( + in_channels=4, + out_channels=4, + patch_size=1, + patch_size_t=1, + num_attention_heads=2, + attention_head_dim=8, + cross_attention_dim=16, + audio_in_channels=4, + audio_out_channels=4, + audio_num_attention_heads=2, + audio_attention_head_dim=4, + audio_cross_attention_dim=8, + num_layers=2, + qk_norm="rms_norm_across_heads", + caption_channels=text_encoder.config.text_config.hidden_size, + rope_double_precision=False, + rope_type="split", + ) + + torch.manual_seed(0) + connectors = LTX2TextConnectors( + caption_channels=text_encoder.config.text_config.hidden_size, + text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1, + video_connector_num_attention_heads=4, + video_connector_attention_head_dim=8, + video_connector_num_layers=1, + video_connector_num_learnable_registers=None, + audio_connector_num_attention_heads=4, + audio_connector_attention_head_dim=8, + audio_connector_num_layers=1, + 
audio_connector_num_learnable_registers=None, + connector_rope_base_seq_len=32, + rope_theta=10000.0, + rope_double_precision=False, + causal_temporal_positioning=False, + rope_type="split", + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTX2Video( + in_channels=3, + out_channels=3, + latent_channels=4, + block_out_channels=(8,), + decoder_block_out_channels=(8,), + layers_per_block=(1,), + decoder_layers_per_block=(1, 1), + spatio_temporal_scaling=(True,), + decoder_spatio_temporal_scaling=(True,), + decoder_inject_noise=(False, False), + downsample_type=("spatial",), + upsample_residual=(False,), + upsample_factor=(1,), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + audio_vae = AutoencoderKLLTX2Audio( + base_channels=4, + output_channels=2, + ch_mult=(1,), + num_res_blocks=1, + attn_resolutions=None, + in_channels=2, + resolution=32, + latent_channels=2, + norm_type="pixel", + causality_axis="height", + dropout=0.0, + mid_block_add_attention=False, + sample_rate=16000, + mel_hop_length=160, + is_causal=True, + mel_bins=8, + ) + + torch.manual_seed(0) + vocoder = LTX2Vocoder( + in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins, + hidden_channels=32, + out_channels=2, + upsample_kernel_sizes=[4, 4], + upsample_factors=[2, 2], + resnet_kernel_sizes=[3], + resnet_dilations=[[1, 3, 5]], + leaky_relu_negative_slope=0.1, + output_sampling_rate=16000, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + components = { + "transformer": transformer, + "vae": vae, + "audio_vae": audio_vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "connectors": connectors, + "vocoder": vocoder, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "a robot dancing", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "num_frames": 5, + "frame_rate": 25.0, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = pipe(**inputs) + video = output.frames + audio = output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.4331, 0.6203, 0.3245, 0.7294, 0.4822, 0.5703, 0.2999, 0.7700, 0.4961, 0.4242, 0.4581, 0.4351, 0.1137, 0.4437, 0.6304, 0.3184 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, 
generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2) diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py new file mode 100644 index 000000000000..1edae9c0e098 --- /dev/null +++ b/tests/pipelines/ltx2/test_ltx2_image2video.py @@ -0,0 +1,241 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2ImageToVideoPipeline, + LTX2VideoTransformer3DModel, +) +from diffusers.pipelines.ltx2 import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder + +from ...testing_utils import enable_full_determinism +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTX2ImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "audio_latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_attention_slicing = False + test_xformers_attention = False + supports_dduf = False + + base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3" + + def get_dummy_components(self): + tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id) + + torch.manual_seed(0) + transformer = LTX2VideoTransformer3DModel( + in_channels=4, + out_channels=4, + patch_size=1, + patch_size_t=1, + num_attention_heads=2, + attention_head_dim=8, + cross_attention_dim=16, + audio_in_channels=4, + audio_out_channels=4, + audio_num_attention_heads=2, + audio_attention_head_dim=4, + audio_cross_attention_dim=8, + num_layers=2, + qk_norm="rms_norm_across_heads", + caption_channels=text_encoder.config.text_config.hidden_size, + rope_double_precision=False, + rope_type="split", + ) + + torch.manual_seed(0) + connectors = LTX2TextConnectors( + caption_channels=text_encoder.config.text_config.hidden_size, + text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1, + video_connector_num_attention_heads=4, + video_connector_attention_head_dim=8, + video_connector_num_layers=1, + video_connector_num_learnable_registers=None, + audio_connector_num_attention_heads=4, + 
audio_connector_attention_head_dim=8, + audio_connector_num_layers=1, + audio_connector_num_learnable_registers=None, + connector_rope_base_seq_len=32, + rope_theta=10000.0, + rope_double_precision=False, + causal_temporal_positioning=False, + rope_type="split", + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTX2Video( + in_channels=3, + out_channels=3, + latent_channels=4, + block_out_channels=(8,), + decoder_block_out_channels=(8,), + layers_per_block=(1,), + decoder_layers_per_block=(1, 1), + spatio_temporal_scaling=(True,), + decoder_spatio_temporal_scaling=(True,), + decoder_inject_noise=(False, False), + downsample_type=("spatial",), + upsample_residual=(False,), + upsample_factor=(1,), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + audio_vae = AutoencoderKLLTX2Audio( + base_channels=4, + output_channels=2, + ch_mult=(1,), + num_res_blocks=1, + attn_resolutions=None, + in_channels=2, + resolution=32, + latent_channels=2, + norm_type="pixel", + causality_axis="height", + dropout=0.0, + mid_block_add_attention=False, + sample_rate=16000, + mel_hop_length=160, + is_causal=True, + mel_bins=8, + ) + + torch.manual_seed(0) + vocoder = LTX2Vocoder( + in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins, + hidden_channels=32, + out_channels=2, + upsample_kernel_sizes=[4, 4], + upsample_factors=[2, 2], + resnet_kernel_sizes=[3], + resnet_dilations=[[1, 3, 5]], + leaky_relu_negative_slope=0.1, + output_sampling_rate=16000, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + components = { + "transformer": transformer, + "vae": vae, + "audio_vae": audio_vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "connectors": connectors, + "vocoder": vocoder, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = torch.rand((1, 3, 32, 32), generator=generator, device=device) + + inputs = { + "image": image, + "prompt": "a robot dancing", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "num_frames": 5, + "frame_rate": 25.0, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = pipe(**inputs) + video = output.frames + audio = output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.3573, 0.8382, 0.3581, 0.6114, 0.3682, 0.7969, 0.2552, 0.6399, 0.3113, 0.1497, 0.3249, 0.5395, 0.3498, 0.4526, 0.4536, 0.4555 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], 
audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)
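
Both fast tests above (and the disabled transformer integration check earlier in this diff) validate outputs against hard-coded slices built from the first and last eight elements of the flattened video and audio tensors, compared with `torch.allclose(..., atol=1e-4, rtol=1e-4)`. When the dummy component configuration changes, those expected values have to be regenerated. The snippet below is a minimal sketch of that pattern only; the helper names `first_last_slice` and `print_expected_slices` are illustrative and are not part of the test suite or the diffusers API.

```python
import torch


def first_last_slice(tensor: torch.Tensor, n: int = 8) -> torch.Tensor:
    """Return the first and last `n` elements of the flattened tensor,
    matching the slice layout used in the LTX-2 fast tests."""
    flat = tensor.detach().cpu().flatten().float()
    return torch.cat([flat[:n], flat[-n:]])


def print_expected_slices(video: torch.Tensor, audio: torch.Tensor) -> None:
    """Print candidate values for `expected_video_slice` / `expected_audio_slice`.

    `video` and `audio` are assumed to be the `output.frames` and `output.audio`
    tensors produced by the pipeline call inside `test_inference`.
    """
    video_slice = first_last_slice(video)
    audio_slice = first_last_slice(audio)
    print("expected_video_slice =", [round(v, 4) for v in video_slice.tolist()])
    print("expected_audio_slice =", [round(v, 4) for v in audio_slice.tolist()])
```

Because the assertions use an absolute and relative tolerance of `1e-4`, regenerated expected values should be recorded to at least four decimal places, as the existing slices in these tests are.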