From 6732907d4d130653c59d44e3fe47278e7878053f Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 30 Apr 2026 12:08:36 +0200 Subject: [PATCH] Squash: ltx-2 --- CMakeLists.txt | 1 + examples/common/common.cpp | 126 ++ examples/common/common.h | 24 + ggml-patch.diff | 20 + include/stable-diffusion.h | 56 + run_ltx2.sh | 72 + src/backend_fit.hpp | 610 +++++++++ src/conditioner.hpp | 362 +++++ src/denoiser.hpp | 91 ++ src/diffusion_model.hpp | 89 ++ src/ggml_extend.hpp | 533 +++++++- src/llm.hpp | 498 ++++++- src/ltx.hpp | 1273 ++++++++++++++++++ src/ltx_connector.hpp | 632 +++++++++ src/ltx_rope.hpp | 350 +++++ src/ltxv.hpp | 110 +- src/ltxvae.hpp | 933 +++++++++++++ src/ltxvae_primitives.hpp | 212 +++ src/model.cpp | 25 +- src/model.h | 13 +- src/name_conversion.cpp | 35 + src/stable-diffusion.cpp | 1152 +++++++++++++++- src/tokenizers/gemma_tokenizer.cpp | 254 ++++ src/tokenizers/gemma_tokenizer.h | 50 + src/vae.hpp | 10 +- tests/ltx_parity/CMakeLists.txt | 134 ++ tests/ltx_parity/README.md | 36 + tests/ltx_parity/dump_av_block.py | 413 ++++++ tests/ltx_parity/dump_av_model.py | 312 +++++ tests/ltx_parity/dump_connector.py | 293 ++++ tests/ltx_parity/dump_gemma.py | 256 ++++ tests/ltx_parity/dump_reference.py | 623 +++++++++ tests/ltx_parity/dump_s2d.py | 176 +++ tests/ltx_parity/dump_vae.py | 341 +++++ tests/ltx_parity/test_attn_chain_parity.cpp | 165 +++ tests/ltx_parity/test_av_block_parity.cpp | 430 ++++++ tests/ltx_parity/test_av_block_smoke.cpp | 206 +++ tests/ltx_parity/test_av_model_parity.cpp | 331 +++++ tests/ltx_parity/test_connector_parity.cpp | 297 ++++ tests/ltx_parity/test_cont_parity.cpp | 129 ++ tests/ltx_parity/test_gemma_cpu_vs_cuda.cpp | 352 +++++ tests/ltx_parity/test_gemma_parity.cpp | 287 ++++ tests/ltx_parity/test_gemma_tokenizer.cpp | 88 ++ tests/ltx_parity/test_ltx2_vae_roundtrip.cpp | 263 ++++ tests/ltx_parity/test_ltx_parity.cpp | 438 ++++++ tests/ltx_parity/test_mm_f32_parity.cpp | 240 ++++ tests/ltx_parity/test_s2d_primitives.cpp | 
185 +++ tests/ltx_parity/test_softmax_parity.cpp | 120 ++ tests/ltx_parity/test_vae_parity.cpp | 378 ++++++ 49 files changed, 13871 insertions(+), 153 deletions(-) create mode 100644 ggml-patch.diff create mode 100755 run_ltx2.sh create mode 100644 src/backend_fit.hpp create mode 100644 src/ltx.hpp create mode 100644 src/ltx_connector.hpp create mode 100644 src/ltx_rope.hpp create mode 100644 src/ltxvae.hpp create mode 100644 src/ltxvae_primitives.hpp create mode 100644 src/tokenizers/gemma_tokenizer.cpp create mode 100644 src/tokenizers/gemma_tokenizer.h create mode 100644 tests/ltx_parity/CMakeLists.txt create mode 100644 tests/ltx_parity/README.md create mode 100644 tests/ltx_parity/dump_av_block.py create mode 100644 tests/ltx_parity/dump_av_model.py create mode 100644 tests/ltx_parity/dump_connector.py create mode 100644 tests/ltx_parity/dump_gemma.py create mode 100644 tests/ltx_parity/dump_reference.py create mode 100644 tests/ltx_parity/dump_s2d.py create mode 100644 tests/ltx_parity/dump_vae.py create mode 100644 tests/ltx_parity/test_attn_chain_parity.cpp create mode 100644 tests/ltx_parity/test_av_block_parity.cpp create mode 100644 tests/ltx_parity/test_av_block_smoke.cpp create mode 100644 tests/ltx_parity/test_av_model_parity.cpp create mode 100644 tests/ltx_parity/test_connector_parity.cpp create mode 100644 tests/ltx_parity/test_cont_parity.cpp create mode 100644 tests/ltx_parity/test_gemma_cpu_vs_cuda.cpp create mode 100644 tests/ltx_parity/test_gemma_parity.cpp create mode 100644 tests/ltx_parity/test_gemma_tokenizer.cpp create mode 100644 tests/ltx_parity/test_ltx2_vae_roundtrip.cpp create mode 100644 tests/ltx_parity/test_ltx_parity.cpp create mode 100644 tests/ltx_parity/test_mm_f32_parity.cpp create mode 100644 tests/ltx_parity/test_s2d_primitives.cpp create mode 100644 tests/ltx_parity/test_softmax_parity.cpp create mode 100644 tests/ltx_parity/test_vae_parity.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6a9fb1041..538b173b5 100644 
--- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -265,6 +265,7 @@ target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17) if (SD_BUILD_EXAMPLES) add_subdirectory(examples) + add_subdirectory(tests/ltx_parity) endif() set(SD_PUBLIC_HEADERS include/stable-diffusion.h) diff --git a/examples/common/common.cpp b/examples/common/common.cpp index 0235c53de..1c92075e1 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -319,6 +319,10 @@ ArgOptions SDContextParams::get_options() { "--qwen2vl_vision", "alias of --llm_vision. Deprecated.", &llm_vision_path}, + {"", + "--gemma-tokenizer", + "path to Gemma's tokenizer.json (HF format). Required for LTX-2 text conditioning.", + &gemma_tokenizer_path}, {"", "--diffusion-model", "path to the standalone diffusion model", @@ -376,6 +380,25 @@ ArgOptions SDContextParams::get_options() { "--chroma-t5-mask-pad", "t5 mask pad size of chroma", &chroma_t5_mask_pad}, + {"", + "--fit-target", + "auto-fit: MiB of free memory to leave on each GPU (default: 512)", + &auto_fit_target_mb}, + {"", + "--fit-compute-reserve-dit", + "auto-fit: MiB reserved on the DiT's GPU for its compute buffer " + "(default: 2048, 0 keeps the built-in default)", + &auto_fit_compute_reserve_dit_mb}, + {"", + "--fit-compute-reserve-vae", + "auto-fit: MiB reserved on the VAE's GPU for its compute buffer " + "(default: 1024, 0 keeps the built-in default)", + &auto_fit_compute_reserve_vae_mb}, + {"", + "--fit-compute-reserve-cond", + "auto-fit: MiB reserved on the conditioner's GPU for its compute " + "buffer (default: 512, 0 keeps the built-in default)", + &auto_fit_compute_reserve_cond_mb}, }; options.float_options = {}; @@ -445,6 +468,27 @@ ArgOptions SDContextParams::get_options() { "--chroma-enable-t5-mask", "enable t5 mask for chroma", true, &chroma_use_t5_mask}, + {"", + "--auto-fit", + "automatically pick DiT/VAE/Conditioner device placements based on " + "free GPU memory (default ON; priority: DiT+compute > VAE > " + "Conditioner; 
overflow goes to CPU or DiT-params-offload mode)", + true, &auto_fit}, + {"", + "--no-auto-fit", + "disable auto-fit and use the explicit placement flags / env vars " + "(--clip-on-cpu, --vae-on-cpu, SD_CUDA_DEVICE*, etc.)", + false, &auto_fit}, + {"", + "--no-tensor-split", + "disable auto tensor split: keep the DiT on a single GPU even when " + "more than one CUDA device is detected. SD_CUDA_TENSOR_SPLIT env " + "still wins when set.", + false, &auto_tensor_split}, + {"", + "--fit-dry-run", + "auto-fit: print the computed plan and exit without loading models", + true, &auto_fit_dry_run}, }; auto on_type_arg = [&](int argc, const char** argv, int index) { @@ -517,6 +561,12 @@ ArgOptions SDContextParams::get_options() { return 1; }; + auto on_no_lazy_load_arg = [&](int /*argc*/, const char** /*argv*/, int /*index*/) { + lazy_load_dit = false; + lazy_load_cond = false; + return 0; + }; + options.manual_options = { {"", "--type", @@ -543,6 +593,12 @@ ArgOptions SDContextParams::get_options() { "but it usually offers faster inference speed and, in some cases, lower memory usage. " "The at_runtime mode, on the other hand, is exactly the opposite.", on_lora_apply_mode_arg}, + {"", + "--no-lazy-load", + "disable lazy load of DiT and conditioner-LLM weights (default ON). 
" + "Lazy load defers per-component allocation+read until first compute() " + "so the working set never holds all components resident.", + on_no_lazy_load_arg}, }; return options; @@ -638,6 +694,7 @@ std::string SDContextParams::to_string() const { << " t5xxl_path: \"" << t5xxl_path << "\",\n" << " llm_path: \"" << llm_path << "\",\n" << " llm_vision_path: \"" << llm_vision_path << "\",\n" + << " gemma_tokenizer_path: \"" << gemma_tokenizer_path << "\",\n" << " diffusion_model_path: \"" << diffusion_model_path << "\",\n" << " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n" << " vae_path: \"" << vae_path << "\",\n" @@ -693,6 +750,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f t5xxl_path.c_str(), llm_path.c_str(), llm_vision_path.c_str(), + gemma_tokenizer_path.c_str(), diffusion_model_path.c_str(), high_noise_diffusion_model_path.c_str(), vae_path.c_str(), @@ -727,6 +785,15 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f chroma_use_t5_mask, chroma_t5_mask_pad, qwen_image_zero_cond_t, + auto_fit, + auto_fit_target_mb, + auto_fit_dry_run, + auto_fit_compute_reserve_dit_mb, + auto_fit_compute_reserve_vae_mb, + auto_fit_compute_reserve_cond_mb, + lazy_load_dit, + lazy_load_cond, + auto_tensor_split, }; return sd_ctx_params; } @@ -841,6 +908,18 @@ ArgOptions SDGenerationParams::get_options() { "--guidance", "distilled guidance scale for models with guidance input (default: 3.5)", &sample_params.guidance.distilled_guidance}, + {"", + "--rescale-scale", + "CFG-rescale to combat oversaturation (default: 0; LTX-2.3 expects 0.7)", + &sample_params.guidance.rescale_scale}, + {"", + "--stg-scale", + "Spatio-Temporal Guidance scale (default: 0; LTX-2.3 expects 1.0 with --stg-blocks [28])", + &sample_params.guidance.stg_scale}, + {"", + "--high-noise-stg-scale", + "(high noise) Spatio-Temporal Guidance scale (default: 0)", + &high_noise_sample_params.guidance.stg_scale}, 
{"", "--slg-scale", "skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means disabled, a value of 2.5 is nice for sd3.5 medium", @@ -1042,6 +1121,36 @@ ArgOptions SDGenerationParams::get_options() { return 1; }; + auto parse_int_list = [](std::string s, std::vector& out) -> bool { + if (s.empty()) return false; + if (s.front() == '[') s.erase(0, 1); + if (!s.empty() && s.back() == ']') s.pop_back(); + std::regex regex("[, ]+"); + std::sregex_token_iterator iter(s.begin(), s.end(), regex, -1); + std::sregex_token_iterator end; + std::vector tmp; + for (auto it = iter; it != end; ++it) { + std::string token = *it; + if (token.empty()) continue; + try { + tmp.push_back(std::stoi(token)); + } catch (const std::invalid_argument&) { + return false; + } + } + out = std::move(tmp); + return true; + }; + + auto on_stg_blocks_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) return -1; + return parse_int_list(argv[index], stg_blocks) ? 1 : -1; + }; + auto on_high_noise_stg_blocks_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) return -1; + return parse_int_list(argv[index], high_noise_stg_blocks) ? 1 : -1; + }; + auto on_sigmas_arg = [&](int argc, const char** argv, int index) { if (++index >= argc) { return -1; @@ -1209,6 +1318,14 @@ ArgOptions SDGenerationParams::get_options() { "--high-noise-skip-layers", "(high noise) layers to skip for SLG steps (default: [7,8,9])", on_high_noise_skip_layers_arg}, + {"", + "--stg-blocks", + "blocks for STG perturbed pass (LTX-2.3 default: [28]). 
Empty disables STG.", + on_stg_blocks_arg}, + {"", + "--high-noise-stg-blocks", + "(high noise) blocks for STG perturbed pass.", + on_high_noise_stg_blocks_arg}, {"-r", "--ref-image", "reference image for Flux Kontext models (can be used multiple times)", @@ -1932,6 +2049,10 @@ sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() { sample_params.guidance.slg.layer_count = skip_layers.size(); high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.empty() ? nullptr : high_noise_skip_layers.data(); high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size(); + sample_params.guidance.stg_blocks = stg_blocks.empty() ? nullptr : stg_blocks.data(); + sample_params.guidance.stg_blocks_count = stg_blocks.size(); + high_noise_sample_params.guidance.stg_blocks = high_noise_stg_blocks.empty() ? nullptr : high_noise_stg_blocks.data(); + high_noise_sample_params.guidance.stg_blocks_count = high_noise_stg_blocks.size(); sample_params.custom_sigmas = custom_sigmas.empty() ? nullptr : custom_sigmas.data(); sample_params.custom_sigmas_count = static_cast(custom_sigmas.size()); cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str(); @@ -1991,6 +2112,10 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() { sample_params.guidance.slg.layer_count = skip_layers.size(); high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.empty() ? nullptr : high_noise_skip_layers.data(); high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size(); + sample_params.guidance.stg_blocks = stg_blocks.empty() ? nullptr : stg_blocks.data(); + sample_params.guidance.stg_blocks_count = stg_blocks.size(); + high_noise_sample_params.guidance.stg_blocks = high_noise_stg_blocks.empty() ? nullptr : high_noise_stg_blocks.data(); + high_noise_sample_params.guidance.stg_blocks_count = high_noise_stg_blocks.size(); sample_params.custom_sigmas = custom_sigmas.empty() ? 
nullptr : custom_sigmas.data(); sample_params.custom_sigmas_count = static_cast(custom_sigmas.size()); cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str(); @@ -2012,6 +2137,7 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() { params.strength = strength; params.seed = seed; params.video_frames = video_frames; + params.fps = static_cast(fps); params.vace_strength = vace_strength; params.vae_tiling_params = vae_tiling_params; params.cache = cache_params; diff --git a/examples/common/common.h b/examples/common/common.h index 5afe89b34..41cdeb33e 100644 --- a/examples/common/common.h +++ b/examples/common/common.h @@ -90,6 +90,7 @@ struct SDContextParams { std::string t5xxl_path; std::string llm_path; std::string llm_vision_path; + std::string gemma_tokenizer_path; std::string diffusion_model_path; std::string high_noise_diffusion_model_path; std::string vae_path; @@ -127,6 +128,25 @@ struct SDContextParams { bool qwen_image_zero_cond_t = false; + // Auto-fit: pick DiT/VAE/Conditioner device placements from free GPU memory. + // Default ON; pass --no-auto-fit to opt out. + bool auto_fit = true; + int auto_fit_target_mb = 512; + bool auto_fit_dry_run = false; + int auto_fit_compute_reserve_dit_mb = 0; // 0 = use header default + int auto_fit_compute_reserve_vae_mb = 0; + int auto_fit_compute_reserve_cond_mb = 0; + + // Lazy load: defer DiT and conditioner-LLM weight allocation+read until + // the first compute(). Default ON; pass --no-lazy-load to opt out. + bool lazy_load_dit = true; + bool lazy_load_cond = true; + + // Auto tensor split: when >1 CUDA device is detected, split DiT row-wise + // across all GPUs by free-VRAM ratio. Default ON; pass --no-tensor-split + // to opt out. SD_CUDA_TENSOR_SPLIT env still wins when set. 
+ bool auto_tensor_split = true; + prediction_t prediction = PREDICTION_COUNT; lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO; @@ -168,6 +188,10 @@ struct SDGenerationParams { sd_sample_params_t high_noise_sample_params; std::vector skip_layers = {7, 8, 9}; std::vector high_noise_skip_layers = {7, 8, 9}; + // STG (Spatio-Temporal Guidance) blocks (LTX-2.3 default: [28]). Empty means + // STG disabled even if --stg-scale is set. + std::vector stg_blocks = {}; + std::vector high_noise_stg_blocks = {}; std::vector custom_sigmas; diff --git a/ggml-patch.diff b/ggml-patch.diff new file mode 100644 index 000000000..747cc80c9 --- /dev/null +++ b/ggml-patch.diff @@ -0,0 +1,20 @@ +diff --git a/src/ggml-cuda/ggml-cuda.cu b/src/ggml-cuda/ggml-cuda.cu +index cc80eb3f..dd79d9c0 100644 +--- a/src/ggml-cuda/ggml-cuda.cu ++++ b/src/ggml-cuda/ggml-cuda.cu +@@ -865,7 +865,14 @@ static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buff + } + + static enum ggml_status ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +- GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported ++ // Views: the view's storage comes from view_src, so we don't allocate ++ // anything on the split buffer for it. The split buffer's per-device ++ // shards aren't applicable to a view — the view simply reuses view_src's ++ // memory layout. Sched will route any op that consumes this view through ++ // view_src's backend. 
++ if (tensor->view_src != nullptr) { ++ return GGML_STATUS_SUCCESS; ++ } + GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors"); + + ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context; diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index a99b10450..23d0e3064 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -76,6 +76,7 @@ enum prediction_t { FLOW_PRED, FLUX_FLOW_PRED, FLUX2_FLOW_PRED, + LTX2_FLOW_PRED, PREDICTION_COUNT }; @@ -169,6 +170,11 @@ typedef struct { const char* t5xxl_path; const char* llm_path; const char* llm_vision_path; + // Path to a HuggingFace-format tokenizer.json file. Currently only read by the + // LTX-2 Gemma 3 conditioner, which requires Gemma's tokenizer for BPE + metaspace + // encoding of prompts. If empty for LTX-2, the conditioner aborts with a clear + // message. Non-LTX-2 pipelines ignore this field. + const char* gemma_tokenizer_path; const char* diffusion_model_path; const char* high_noise_diffusion_model_path; const char* vae_path; @@ -203,6 +209,36 @@ typedef struct { bool chroma_use_t5_mask; int chroma_t5_mask_pad; bool qwen_image_zero_cond_t; + + // Auto-fit: pick DiT/VAE/Conditioner devices based on free GPU memory. + // When `auto_fit` is true (default), the CLI placement overrides (env vars, + // keep_*_on_cpu) are ignored and the plan is computed automatically. + // `auto_fit_target_mb` is the memory to leave free per GPU (default 512). + // `auto_fit_dry_run` prints the plan and aborts init before loading. + // `auto_fit_compute_reserve_{dit,vae,cond}_mb` let the user tune the + // per-component compute-buffer reserve; 0 means use the built-in default. 
+ bool auto_fit; + int auto_fit_target_mb; + bool auto_fit_dry_run; + int auto_fit_compute_reserve_dit_mb; + int auto_fit_compute_reserve_vae_mb; + int auto_fit_compute_reserve_cond_mb; + + // Lazy load: defer DiT and conditioner-LLM weight allocation+read until + // the first compute() call, so the working set never holds all components + // resident simultaneously. Required when sum-of-components exceeds combined + // VRAM (LTX-2 + Gemma + VAE all at once won't fit on a 24 GB rig). + // Defaults to true. Env-var overrides (SD_LAZY_LOAD_DIT/SD_LAZY_LOAD_COND) + // still work and force-enable when set; they cannot disable. + bool lazy_load_dit; + bool lazy_load_cond; + + // Auto tensor split: when more than one CUDA device is detected and the + // user did NOT explicitly set SD_CUDA_TENSOR_SPLIT, automatically split + // the DiT row-wise across all available GPUs with ratios proportional to + // each device's free VRAM. Defaults to true. Set false to keep the DiT + // on a single GPU (with auto-fit choosing which one). + bool auto_tensor_split; } sd_ctx_params_t; typedef struct { @@ -225,6 +261,20 @@ typedef struct { float img_cfg; float distilled_guidance; sd_slg_params_t slg; + // CFG-rescale (LTX-2.3 default 0.7, no effect when 0). After CFG mixing, + // pred is rescaled toward cond's std to combat oversaturation: + // factor = cond.std() / pred.std() + // factor = rescale_scale * factor + (1 - rescale_scale) + // pred *= factor + float rescale_scale; + // Spatio-Temporal Guidance (LTX-2.3 default stg_scale=1.0, stg_blocks=[28]). + // Adds a third forward pass with self-attention skipped on the listed + // transformer blocks; the resulting "weakened" prediction is mixed into + // the guided pred: pred += stg_scale * (cond - perturbed). + // No effect when stg_scale==0 or stg_blocks_count==0. 
+ float stg_scale; + int* stg_blocks; + size_t stg_blocks_count; } sd_guidance_params_t; typedef struct { @@ -332,6 +382,12 @@ typedef struct { float strength; int64_t seed; int video_frames; + // Output video fps. Carried through to models that use it for temporal + // positional embeddings — LTX-2's RoPE divides the time axis by fps + // (ltx_core/tools.py::VideoLatentTools.create_initial_state), so the + // default 24 on LTXRunner silently produces wrong positions at any + // other target fps. 0 means "don't override runner default". + float fps; float vace_strength; sd_tiling_params_t vae_tiling_params; sd_cache_params_t cache; diff --git a/run_ltx2.sh b/run_ltx2.sh new file mode 100755 index 000000000..92dde69a6 --- /dev/null +++ b/run_ltx2.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash +# Run LTX-2.3 22B-dev video generation with the current local model paths. +# All placement / lazy-load / tensor-split is auto-detected (defaults to ON), +# so this script only needs the prompt and output path. 
+# +# Usage: +# ./run_ltx2.sh "" [output.webm] +# +# Override knobs by editing the variables below, or by exporting them before +# the call: +# STEPS=20 SEED=99 ./run_ltx2.sh "a sunset" out.webm + +set -euo pipefail + +# --- Models (current local paths) ----------------------------------------- +DIT="/media/ilintar/D_SSD/models/ltx-2/ltx-2.3-22b-dev-Q6_K.gguf" +LLM="/media/ilintar/D_SSD/models/ltx-2/gemma-3-12b-it-UD-Q8_K_XL.gguf" +VAE="/media/ilintar/D_SSD/models/ltx-2/ltx-2.3-22b-dev_video_vae.safetensors" +EMB="/media/ilintar/D_SSD/models/ltx-2/ltx-2.3-22b-dev_embeddings_connectors.safetensors" +TOK="/media/ilintar/D_SSD/models/ltx-2/gemma_tokenizer.json" + +# --- Core video parameters (override via env or edit here) ---------------- +WIDTH="${WIDTH:-480}" +HEIGHT="${HEIGHT:-320}" +FRAMES="${FRAMES:-25}" +FPS="${FPS:-25}" +STEPS="${STEPS:-30}" +SEED="${SEED:-7}" + +# --- Sampling / guidance -------------------------------------------------- +CFG_SCALE="${CFG_SCALE:-3.0}" +RESCALE_SCALE="${RESCALE_SCALE:-0.7}" +STG_SCALE="${STG_SCALE:-1.0}" +STG_BLOCKS="${STG_BLOCKS:-[28]}" + +# --- Args ----------------------------------------------------------------- +PROMPT="${1:-a cinematic photograph of a sunset over the ocean}" +OUTPUT="${2:-output.webm}" +NEG_PROMPT="${NEG_PROMPT:-}" + +# --- Resolve sd-cli path -------------------------------------------------- +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SD_CLI="${SD_CLI:-${SCRIPT_DIR}/build/bin/sd-cli}" +if [[ ! 
-x "$SD_CLI" ]]; then + echo "sd-cli not found at $SD_CLI — build it first (cmake --build build -j)" >&2 + exit 1 +fi + +echo "running LTX-2 vid_gen → $OUTPUT" +echo " ${WIDTH}x${HEIGHT}, ${FRAMES} frames @ ${FPS} fps, ${STEPS} steps, seed=${SEED}" +echo " cfg=${CFG_SCALE}, rescale=${RESCALE_SCALE}, stg=${STG_SCALE} blocks=${STG_BLOCKS}" + +NEG_ARGS=() +if [[ -n "$NEG_PROMPT" ]]; then + NEG_ARGS=(-n "$NEG_PROMPT") +fi + +SD_QUIET_UNKNOWN_TENSORS=1 exec "$SD_CLI" -M vid_gen \ + --diffusion-model "$DIT" \ + --llm "$LLM" \ + --vae "$VAE" \ + -m "$EMB" \ + --gemma-tokenizer "$TOK" \ + -W "$WIDTH" -H "$HEIGHT" --video-frames "$FRAMES" --fps "$FPS" \ + --steps "$STEPS" --seed "$SEED" \ + --cfg-scale "$CFG_SCALE" --rescale-scale "$RESCALE_SCALE" \ + --stg-scale "$STG_SCALE" --stg-blocks "$STG_BLOCKS" \ + --diffusion-fa \ + --mmap \ + -p "$PROMPT" \ + "${NEG_ARGS[@]}" \ + -o "$OUTPUT" diff --git a/src/backend_fit.hpp b/src/backend_fit.hpp new file mode 100644 index 000000000..9f92f07f9 --- /dev/null +++ b/src/backend_fit.hpp @@ -0,0 +1,610 @@ +#ifndef __SD_BACKEND_FIT_HPP__ +#define __SD_BACKEND_FIT_HPP__ + +// Auto-fit algorithm for distributing DiT, VAE, and conditioner (LLM + +// connector) across available GPU devices and system RAM. +// +// Inspired by llama.cpp's common_fit_params (tools/fit-params), but much +// coarser: sd.cpp treats each of {DiT, VAE, Conditioner} as a single atomic +// unit that lives entirely on one device (plus the DiT's compute buffer on +// the same GPU). There is no per-layer tensor_buft_overrides mechanism in +// sd.cpp today — the existing `offload_params_to_cpu` knob is the only way to +// "split" a model (it keeps params in RAM and streams them to the runtime +// backend per forward pass). +// +// Placement priority: DiT + compute buffer → VAE → Conditioner (+connector). +// Overflow falls back to CPU (or GPU_OFFLOAD_PARAMS for DiT). 
+ +#include +#include +#include +#include +#include +#include + +#include "ggml.h" + +#ifdef SD_USE_CUDA +#include "ggml-cuda.h" +#endif +#if defined(SD_USE_VULKAN) +#include "ggml-backend.h" +#endif + +#include "model.h" +#include "util.h" + +namespace backend_fit { + +constexpr int64_t MiB = 1024 * 1024; +constexpr int DEVICE_ID_CPU = -1; + +enum class ComponentKind { + DIT, + VAE, + CONDITIONER, // LLM + connector (share a backend) +}; + +enum class Placement { + CPU, + GPU, + GPU_OFFLOAD_PARAMS, // params in RAM, compute on GPU (DiT-only) + GPU_TENSOR_SPLIT, // params row-split across all GPUs (multi-GPU only) +}; + +struct Component { + ComponentKind kind; + std::string name; + int64_t params_bytes = 0; // weight memory for this component + int64_t compute_bytes = 0; // reserved compute buffer on the chosen device + bool supports_offload = false; // true only for DiT +}; + +struct Device { + int id = DEVICE_ID_CPU; + std::string name; + std::string description; + int64_t free_bytes = 0; + int64_t total_bytes = 0; +}; + +struct Decision { + ComponentKind kind; + std::string name; + Placement placement = Placement::CPU; + int device_id = DEVICE_ID_CPU; + int64_t on_device_bytes = 0; // contribution to device_id's device memory + int64_t on_host_bytes = 0; // contribution to host RAM +}; + +struct Plan { + std::vector decisions; + std::map device_bytes; // device_id -> bytes used + int64_t host_bytes = 0; + bool any_changes = false; // true if a non-default placement was chosen +}; + +// Defaults chosen to leave enough headroom for typical diffusion/video models. +// Configurable via the CLI (--fit-compute-reserve-* in MiB). 
+struct ComputeReserves { + int64_t dit_bytes = int64_t(2048) * MiB; // video DiT compute buffer + int64_t vae_bytes = int64_t(1024) * MiB; // video VAE compute buffer + int64_t conditioner_bytes = int64_t(512) * MiB; // LLM + connector combined +}; + +// --- Classification ------------------------------------------------------- + +// Classify a tensor name into a ComponentKind. Returns false if the tensor is +// unused / not a primary weight we should count. +inline bool classify_tensor(const std::string& name, ComponentKind& out) { + // Connector lives inside `model.diffusion_model.*` by prefix but runs on + // the conditioner's backend, so it gets charged to CONDITIONER. + auto contains = [&](const char* s) { return name.find(s) != std::string::npos; }; + + // LTX-2 specific: the checkpoint carries audio-to-video branch weights + // (`.audio_*`, `.audio_to_video_*`, `.video_to_audio_*`, `audio_patchify_*`, + // `audio_scale_shift_*`, `audio_prompt_*`) that the video-only LTX2 + // diffusion module does NOT wire in. They're logged as "unknown tensor" + // warnings at load time and skipped. Excluding them here keeps the DiT + // params estimate honest (~9 GB) instead of including ~4 GB of audio + // tensors that never touch the GPU. 
+ if (contains(".audio_") || + contains("audio_patchify") || + contains("audio_aggregate") || + contains("audio_scale_shift") || + contains("audio_prompt") || + contains("a2v_ca_audio") || + contains("a2v_ca_video")) { + return false; + } + + if (contains("embeddings_connector") || + contains("aggregate_embed") || + contains("text_embedding_projection")) { + out = ComponentKind::CONDITIONER; + return true; + } + + if (contains("model.diffusion_model.") || contains("unet.")) { + out = ComponentKind::DIT; + return true; + } + + if (contains("first_stage_model.") || + name.rfind("vae.", 0) == 0 || + name.rfind("tae.", 0) == 0) { + out = ComponentKind::VAE; + return true; + } + + if (contains("text_encoders") || + contains("cond_stage_model") || + contains("te.text_model.") || + contains("conditioner") || + name.rfind("text_encoder.", 0) == 0) { + out = ComponentKind::CONDITIONER; + return true; + } + + return false; +} + +// --- Memory estimation ---------------------------------------------------- + +// Sum params bytes per component using the same alignment padding and +// dtype-conversion rules as ModelLoader::get_params_mem_size. +inline std::vector estimate_components(ModelLoader& loader, + ggml_type override_wtype, + int64_t alignment, + const ComputeReserves& reserves) { + auto& storage = loader.get_tensor_storage_map(); + + int64_t bytes[3] = {0, 0, 0}; // DIT, VAE, CONDITIONER + int counts[3] = {0, 0, 0}; + + for (auto& [name, ts_const] : storage) { + // Work on a copy so we can apply the dtype override without mutating. + TensorStorage ts = ts_const; + if (is_unused_tensor(ts.name)) { + continue; + } + + ComponentKind k; + if (!classify_tensor(ts.name, k)) { + continue; + } + + if (override_wtype != GGML_TYPE_COUNT && + loader.tensor_should_be_converted(ts, override_wtype)) { + ts.type = override_wtype; + } else if (ts.expected_type != GGML_TYPE_COUNT && ts.expected_type != ts.type) { + // Honor per-tensor retypes (e.g. 
LTX-2 Gemma → q8_0 fix in + // stable-diffusion.cpp) when computing component size. + ts.type = ts.expected_type; + } + + int idx = int(k); + bytes[idx] += ts.nbytes() + alignment; + counts[idx] += 1; + } + + std::vector out; + out.reserve(3); + + out.push_back(Component{ + ComponentKind::DIT, "DiT", + bytes[int(ComponentKind::DIT)], reserves.dit_bytes, + /*supports_offload=*/true, + }); + out.push_back(Component{ + ComponentKind::VAE, "VAE", + bytes[int(ComponentKind::VAE)], reserves.vae_bytes, + /*supports_offload=*/false, + }); + out.push_back(Component{ + ComponentKind::CONDITIONER, "Conditioner", + bytes[int(ComponentKind::CONDITIONER)], reserves.conditioner_bytes, + /*supports_offload=*/true, // Gemma/etc. can stream params to GPU per encode + }); + + (void)counts; + return out; +} + +// --- Device enumeration --------------------------------------------------- + +inline std::vector enumerate_gpu_devices() { + std::vector out; + +#if defined(SD_USE_CUDA) + int count = ggml_backend_cuda_get_device_count(); + for (int i = 0; i < count; i++) { + Device d; + d.id = i; + char desc_buf[256] = {0}; + ggml_backend_cuda_get_device_description(i, desc_buf, sizeof(desc_buf)); + d.description = desc_buf; + d.name = "CUDA" + std::to_string(i); + size_t free_b = 0, total_b = 0; + ggml_backend_cuda_get_device_memory(i, &free_b, &total_b); + d.free_bytes = int64_t(free_b); + d.total_bytes = int64_t(total_b); + out.push_back(d); + } +#elif defined(SD_USE_VULKAN) + int count = ggml_backend_vk_get_device_count(); + for (int i = 0; i < count; i++) { + Device d; + d.id = i; + d.name = "Vulkan" + std::to_string(i); + // Vulkan backend does not expose a direct free-memory API; enumerate + // via ggml_backend_dev so we can reuse ggml_backend_dev_memory. 
+ ggml_backend_dev_t dev = nullptr; + for (size_t j = 0; j < ggml_backend_dev_count(); j++) { + ggml_backend_dev_t candidate = ggml_backend_dev_get(j); + if (ggml_backend_dev_type(candidate) == GGML_BACKEND_DEVICE_TYPE_GPU && + std::string(ggml_backend_dev_name(candidate)).find("Vulkan") != std::string::npos) { + if (int(j) == i) { dev = candidate; break; } + } + } + if (dev) { + d.description = ggml_backend_dev_description(dev); + size_t free_b = 0, total_b = 0; + ggml_backend_dev_memory(dev, &free_b, &total_b); + d.free_bytes = int64_t(free_b); + d.total_bytes = int64_t(total_b); + } + out.push_back(d); + } +#endif + + return out; +} + +// --- Core algorithm ------------------------------------------------------- + +// Peak VRAM per GPU is computed from two contributions: +// 1. `nonoffload_sum` — sum of params of every non-offload component on +// that GPU. These live on VRAM from LOAD through their free-after-use +// point, overlapping during the load window. +// 2. `max_active_footprint` — the largest per-phase compute footprint, +// where a non-offload component's phase contributes just its compute +// buffer, and an offload component's phase contributes params+compute +// (its runtime buffer is full-size while active, freed by +// `free_compute_buffer_immediately=true` between phases). +// peak = nonoffload_sum + max_active_footprint. This is conservative: it +// assumes the load-time accumulation overlaps with an active compute phase +// of the worst-case component. In practice load finishes before any compute +// starts so this over-counts by max_active_footprint during load — safe. +// Compute the per-GPU split share of a tensor-split component, weighted by +// each device's free VRAM. Returns a vector of size devices.size() with each +// device's portion of params_bytes (the same ratio applies to compute_bytes). 
+inline std::vector tensor_split_ratios(const std::vector& devices) { + double total = 0.0; + std::vector ratios(devices.size(), 0.0); + for (size_t g = 0; g < devices.size(); g++) { + ratios[g] = std::max(0, devices[g].free_bytes); + total += ratios[g]; + } + if (total <= 0.0) { + // Fallback: equal split. + std::fill(ratios.begin(), ratios.end(), 1.0 / std::max(1, devices.size())); + return ratios; + } + for (auto& r : ratios) r /= total; + return ratios; +} + +// Peak GPU memory per device. Components time-share VRAM at runtime +// (free_params_immediately frees params between phases), so peak per device +// is the MAX of any single component's resident footprint on that device, +// not the SUM. Footprint = params + compute (for whichever placement mode +// applies). +inline int64_t gpu_peak(int gpu_idx, + const std::vector& pl, + const std::vector& dev, + const std::vector& components, + const std::vector& devices = {}) { + int64_t peak = 0; + + std::vector split_ratios; + bool any_split = false; + for (size_t i = 0; i < components.size(); i++) { + if (pl[i] == Placement::GPU_TENSOR_SPLIT) { any_split = true; break; } + } + if (any_split && !devices.empty()) { + split_ratios = tensor_split_ratios(devices); + } + + for (size_t i = 0; i < components.size(); i++) { + const Component& c = components[i]; + int64_t footprint = 0; + if (pl[i] == Placement::GPU) { + if (dev[i] != gpu_idx) continue; + footprint = c.params_bytes + c.compute_bytes; + } else if (pl[i] == Placement::GPU_OFFLOAD_PARAMS) { + if (dev[i] != gpu_idx) continue; + footprint = c.params_bytes + c.compute_bytes; + } else if (pl[i] == Placement::GPU_TENSOR_SPLIT) { + if (gpu_idx < 0 || size_t(gpu_idx) >= split_ratios.size()) continue; + double r = split_ratios[gpu_idx]; + int64_t share = int64_t(double(c.params_bytes + c.compute_bytes) * r); + footprint = share; + } + peak = std::max(peak, footprint); + } + return peak; +} + +inline Plan compute_plan(const std::vector& components, + const std::vector& 
devices, + int64_t margin_bytes, + bool allow_tensor_split = false) { + // Enumeration approach: for each component we have up to (1 + 2 * nGPU) + // placement options — CPU, or non-offload / offload on each GPU (offload + // only when the component supports it). We try all combinations, filter + // infeasible ones (any GPU's computed peak exceeds its free-margin cap), + // and pick the combination with the best score. + // + // Score rewards GPU placement (heavily), non-offload over offload + // (avoids per-step stream cost), and GPU diversity (use multiple GPUs + // when possible instead of packing onto one). Priority runtime hot + // components are weighted higher: DiT >> Conditioner > VAE. + const size_t nC = components.size(); + const size_t nG = devices.size(); + + std::vector cap(nG, 0); + for (size_t g = 0; g < nG; g++) { + cap[g] = devices[g].free_bytes - margin_bytes; + if (cap[g] < 0) cap[g] = 0; + } + + struct OptionSlot { + Placement placement; + int device_idx; // index into devices, or -1 for CPU + }; + + auto build_options = [&](const Component& c) { + std::vector opts; + for (size_t g = 0; g < nG; g++) { + opts.push_back({Placement::GPU, int(g)}); + if (c.supports_offload) { + opts.push_back({Placement::GPU_OFFLOAD_PARAMS, int(g)}); + } + } + // Tensor split: only for the heavy components (DiT, Conditioner) and + // only when there is more than one GPU. The VAE is too light to be + // worth splitting and isn't currently wired for split anyway. 
+ if (allow_tensor_split && nG >= 2 && + (c.kind == ComponentKind::DIT || c.kind == ComponentKind::CONDITIONER)) { + opts.push_back({Placement::GPU_TENSOR_SPLIT, -1}); + } + opts.push_back({Placement::CPU, -1}); + return opts; + }; + + std::vector> options; + options.reserve(nC); + for (const Component& c : components) { + options.push_back(build_options(c)); + } + + auto priority_weight = [](ComponentKind k) -> int { + switch (k) { + case ComponentKind::DIT: return 300; // runs N times per generation + case ComponentKind::CONDITIONER: return 120; // one large forward per prompt + case ComponentKind::VAE: return 60; // one decode per generation + } + return 1; + }; + + auto score = [&](const std::vector& pl, + const std::vector& dev) { + int64_t s = 0; + std::set gpus_used; + for (size_t i = 0; i < nC; i++) { + const int pw = priority_weight(components[i].kind); + if (pl[i] == Placement::GPU) { + s += 10 * pw; + gpus_used.insert(dev[i]); + } else if (pl[i] == Placement::GPU_OFFLOAD_PARAMS) { + s += 5 * pw; // still on GPU but with per-step stream overhead + gpus_used.insert(dev[i]); + } else if (pl[i] == Placement::GPU_TENSOR_SPLIT) { + // Better than CPU but worse than fitting on a single GPU + // (cross-GPU traffic per layer). Use 7 * pw so it's preferred + // over OFFLOAD_PARAMS only when the latter would not fit. + s += 7 * pw; + for (size_t g = 0; g < devices.size(); g++) gpus_used.insert(int(g)); + } else { + s -= 10 * pw; + } + } + s += 2 * int64_t(gpus_used.size()); // mild spread bonus + return s; + }; + + std::vector idx(nC, 0); + std::vector best_pl; + std::vector best_dev; + int64_t best_score = std::numeric_limits::min(); + bool found_any = false; + + // Iterate the cartesian product of options. + while (true) { + std::vector pl(nC); + std::vector dev(nC); + for (size_t i = 0; i < nC; i++) { + pl[i] = options[i][idx[i]].placement; + dev[i] = options[i][idx[i]].device_idx; + } + // Feasibility check: peak on each GPU vs cap. 
+ bool feasible = true; + for (size_t g = 0; g < nG; g++) { + if (gpu_peak(int(g), pl, dev, components, devices) > cap[g]) { + feasible = false; + break; + } + } + if (feasible) { + int64_t sc = score(pl, dev); + if (sc > best_score) { + best_score = sc; + best_pl = pl; + best_dev = dev; + found_any = true; + } + } + + // Advance mixed-radix counter. + size_t pos = 0; + while (pos < nC) { + idx[pos]++; + if (idx[pos] < options[pos].size()) break; + idx[pos] = 0; + pos++; + } + if (pos >= nC) break; + } + + Plan plan; + if (!found_any) { + // Degenerate: no feasible solution (even all-CPU must be feasible by + // construction; but guard anyway). Fall back to CPU for everything. + best_pl.assign(nC, Placement::CPU); + best_dev.assign(nC, -1); + } + + std::vector split_ratios; + for (size_t i = 0; i < nC; i++) { + if (best_pl[i] == Placement::GPU_TENSOR_SPLIT) { + split_ratios = tensor_split_ratios(devices); + break; + } + } + + for (size_t i = 0; i < nC; i++) { + const Component& c = components[i]; + Decision d; + d.kind = c.kind; + d.name = c.name; + d.placement = best_pl[i]; + if (best_pl[i] == Placement::CPU) { + d.device_id = DEVICE_ID_CPU; + d.on_host_bytes = c.params_bytes + c.compute_bytes; + plan.any_changes = true; + } else if (best_pl[i] == Placement::GPU_TENSOR_SPLIT) { + // device_id == DEVICE_ID_CPU is a sentinel meaning "all GPUs" + // for split decisions. on_device_bytes records the largest + // per-GPU share (peak). on_host_bytes stays 0. 
+ d.device_id = DEVICE_ID_CPU; + int64_t max_share = 0; + for (size_t g = 0; g < nG; g++) { + int64_t share = int64_t(double(c.params_bytes + c.compute_bytes) * + split_ratios[g]); + max_share = std::max(max_share, share); + } + d.on_device_bytes = max_share; + plan.any_changes = true; + } else { + d.device_id = devices[best_dev[i]].id; + if (best_pl[i] == Placement::GPU) { + d.on_device_bytes = c.params_bytes + c.compute_bytes; + } else { // GPU_OFFLOAD_PARAMS + d.on_device_bytes = c.params_bytes + c.compute_bytes; // peak during its compute + d.on_host_bytes = c.params_bytes; + plan.any_changes = true; + } + } + plan.decisions.push_back(d); + plan.host_bytes += d.on_host_bytes; + } + + // Report per-device peak using the same model as feasibility check. + for (size_t g = 0; g < nG; g++) { + plan.device_bytes[devices[g].id] = gpu_peak(int(g), best_pl, best_dev, components, devices); + } + return plan; +} + +inline const char* placement_str(Placement p) { + switch (p) { + case Placement::CPU: return "CPU"; + case Placement::GPU: return "GPU"; + case Placement::GPU_OFFLOAD_PARAMS: return "GPU(params->RAM)"; + case Placement::GPU_TENSOR_SPLIT: return "GPU(tensor-split)"; + } + return "?"; +} + +inline void print_plan(const Plan& plan, + const std::vector& components, + const std::vector& devices, + int64_t margin_bytes) { + LOG_INFO("auto-fit plan (margin=%lld MiB per GPU):", + (long long)(margin_bytes / MiB)); + LOG_INFO(" available devices:"); + if (devices.empty()) { + LOG_INFO(" (no GPU devices detected — all components will run on CPU)"); + } + for (const Device& d : devices) { + LOG_INFO(" %-8s %-32s free %6lld / %6lld MiB", + d.name.c_str(), d.description.c_str(), + (long long)(d.free_bytes / MiB), + (long long)(d.total_bytes / MiB)); + } + LOG_INFO(" components:"); + for (const Component& c : components) { + LOG_INFO(" %-12s params %6lld MiB, compute reserve %6lld MiB", + c.name.c_str(), + (long long)(c.params_bytes / MiB), + (long long)(c.compute_bytes / 
MiB)); + } + LOG_INFO(" decisions:"); + for (const Decision& d : plan.decisions) { + if (d.placement == Placement::CPU) { + LOG_INFO(" %-12s -> CPU (RAM %lld MiB)", + d.name.c_str(), (long long)(d.on_host_bytes / MiB)); + } else if (d.placement == Placement::GPU) { + LOG_INFO(" %-12s -> GPU %d (VRAM %lld MiB)", + d.name.c_str(), d.device_id, + (long long)(d.on_device_bytes / MiB)); + } else if (d.placement == Placement::GPU_TENSOR_SPLIT) { + LOG_INFO(" %-12s -> tensor-split (VRAM peak %lld MiB on largest-share GPU)", + d.name.c_str(), + (long long)(d.on_device_bytes / MiB)); + } else { + LOG_INFO(" %-12s -> GPU %d (params RAM) (VRAM %lld MiB, RAM %lld MiB)", + d.name.c_str(), d.device_id, + (long long)(d.on_device_bytes / MiB), + (long long)(d.on_host_bytes / MiB)); + } + } + LOG_INFO(" projected per-device peak (MAX of assigned components, " + "since free_params_immediately lets components time-share VRAM):"); + for (const Device& d : devices) { + int64_t peak = 0; + auto it = plan.device_bytes.find(d.id); + if (it != plan.device_bytes.end()) peak = it->second; + const int64_t remaining = d.free_bytes - peak; + LOG_INFO(" %-8s peak %6lld / %6lld MiB free (remaining %lld MiB)", + d.name.c_str(), + (long long)(peak / MiB), + (long long)(d.free_bytes / MiB), + (long long)(remaining / MiB)); + } + LOG_INFO(" %-8s host RAM additional %lld MiB", "CPU", + (long long)(plan.host_bytes / MiB)); +} + +// Convenience: look up the decision for a specific component. 
+inline const Decision* find_decision(const Plan& plan, ComponentKind kind) { + for (const Decision& d : plan.decisions) { + if (d.kind == kind) return &d; + } + return nullptr; +} + +} // namespace backend_fit + +#endif // __SD_BACKEND_FIT_HPP__ diff --git a/src/conditioner.hpp b/src/conditioner.hpp index 9f4d45524..b7b00d115 100644 --- a/src/conditioner.hpp +++ b/src/conditioner.hpp @@ -5,8 +5,10 @@ #include "clip.hpp" #include "llm.hpp" +#include "ltx_connector.hpp" #include "t5.hpp" #include "tensor_ggml.hpp" +#include "tokenizers/gemma_tokenizer.h" struct SDCondition { sd::Tensor c_crossattn; @@ -94,6 +96,18 @@ struct Conditioner { virtual std::string remove_trigger_from_prompt(const std::string& prompt) { GGML_ABORT("Not implemented yet!"); } + // Lazy-load hook on the LLM/Gemma side. Default no-op for conditioners + // whose state is too small to need it. Overridden by LTX2GemmaConditioner. + virtual void set_llm_lazy_load(std::function /*fn*/) {} + // Tensor-map split for the lazy LLM path: populate `tensors` with ONLY the + // LLM's tensors (so the lazy callback knows what to load) or with + // EVERYTHING EXCEPT the LLM (so the global eager load skips them). Default + // no-op / delegates to get_param_tensors for conditioners without an LLM + // split point. + virtual void get_llm_param_tensors(std::map& /*tensors*/) {} + virtual void get_non_llm_param_tensors(std::map& tensors) { + get_param_tensors(tensors); + } }; // ldm.modules.encoders.modules.FrozenCLIPEmbedder @@ -1958,4 +1972,352 @@ struct LLMEmbedder : public Conditioner { } }; +// LTX-2 conditioner: Gemma 3 text encoder → feature extractor → 1D connector → +// DiT cross-attention context. Supports both V1 (19B) and V2 (22B) feature +// extractor variants, auto-detected from the tensor map. 
+// +// Key prefixes (native LTX-2 checkpoint layout, no name-conversion applied): +// text_encoder.model.* Gemma weights +// text_embedding_projection.aggregate_embed.* V1 FeatureExtractorV1 (19B) +// text_embedding_projection.video_aggregate_embed.* V2 FeatureExtractorV2 video branch (22B) +// text_embedding_projection.audio_aggregate_embed.* V2 audio branch (22B, currently unused) +// model.diffusion_model.embeddings_connector.* V1 Embeddings1DConnector (19B) +// model.diffusion_model.video_embeddings_connector.* V2 video connector (22B) +// model.diffusion_model.caption_projection.* V1 PixArt caption_projection (on DiT) +// (V2 has no caption_projection — feature +// extractor already outputs DiT's inner_dim) +// +// If neither V1 nor V2 connector weights are present (e.g. Gemma-only test +// checkpoints), the conditioner falls back to returning the final post-norm +// hidden state — the same cheap path we had before Phase 9 landed. +struct LTX2GemmaConditioner : public Conditioner { + std::shared_ptr llm; + std::shared_ptr tokenizer; + std::shared_ptr connector_runner; + std::string prefix; + std::string tokenizer_path; + int64_t gemma_hidden_size = 0; + int gemma_num_hidden_layers = 0; + // True when using the V2 (22B) feature extractor; used by get_learned_condition + // to pick the right CPU normalization path. 
+ bool use_v2_feature_extractor = false; + + LTX2GemmaConditioner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string prefix = "text_encoder", + const std::string tokenizer_path = "", + const std::string feat_ext_prefix = "text_embedding_projection", + const std::string connector_prefix_arg = "") + : prefix(prefix), tokenizer_path(tokenizer_path) { + llm = std::make_shared(LLM::LLMArch::GEMMA3, + backend, + offload_params_to_cpu, + tensor_storage_map, + prefix, + /*enable_vision=*/false); + gemma_hidden_size = llm->params.hidden_size; + gemma_num_hidden_layers = static_cast(llm->params.num_layers); + + if (!tokenizer_path.empty()) { + tokenizer = std::make_shared(); + if (!tokenizer->load_from_file(tokenizer_path)) { + LOG_WARN("LTX2GemmaConditioner: failed to load Gemma tokenizer from '%s'", tokenizer_path.c_str()); + tokenizer.reset(); + } + } + + // Auto-detect V1 vs V2 feature extractor + connector prefix variant. + // V2 (22B): text_embedding_projection.video_aggregate_embed.{weight,bias} + + // model.diffusion_model.video_embeddings_connector.* + // V1 (19B): text_embedding_projection.aggregate_embed.weight + + // model.diffusion_model.embeddings_connector.* + // `connector_prefix_arg` is honored when non-empty, otherwise we probe both. 
+ const std::string& feat_ext_pre = feat_ext_prefix; + + auto agg_v1_it = tensor_storage_map.find(feat_ext_pre + ".aggregate_embed.weight"); + auto agg_v2_it = tensor_storage_map.find(feat_ext_pre + ".video_aggregate_embed.weight"); + + std::string connector_pre; + LTXConnector::FeatureExtractorVersion fe_version = LTXConnector::FeatureExtractorVersion::V1; + int64_t flat_dim = 0; + int64_t inner_dim = 0; + + if (agg_v2_it != tensor_storage_map.end()) { + fe_version = LTXConnector::FeatureExtractorVersion::V2; + flat_dim = agg_v2_it->second.ne[0]; + inner_dim = agg_v2_it->second.ne[1]; + use_v2_feature_extractor = true; + connector_pre = connector_prefix_arg.empty() + ? "model.diffusion_model.video_embeddings_connector" + : connector_prefix_arg; + } else if (agg_v1_it != tensor_storage_map.end()) { + fe_version = LTXConnector::FeatureExtractorVersion::V1; + flat_dim = agg_v1_it->second.ne[0]; + inner_dim = agg_v1_it->second.ne[1]; + connector_pre = connector_prefix_arg.empty() + ? "model.diffusion_model.embeddings_connector" + : connector_prefix_arg; + } else { + LOG_INFO("LTX2GemmaConditioner: no feature_extractor weights found — falling back to " + "last_hidden_state pass-through (Gemma-only mode)"); + return; + } + + auto conn0_it = tensor_storage_map.find(connector_pre + ".transformer_1d_blocks.0.attn1.to_q.weight"); + if (conn0_it == tensor_storage_map.end()) { + LOG_WARN("LTX2GemmaConditioner: feature_extractor weights present but connector at '%s' is missing; " + "falling back to last_hidden_state", + connector_pre.c_str()); + return; + } + if (conn0_it->second.ne[1] != inner_dim) { + LOG_WARN("LTX2GemmaConditioner: connector to_q out_features=%lld does not match " + "feature_extractor inner_dim=%lld; skipping connector.", + (long long)conn0_it->second.ne[1], (long long)inner_dim); + return; + } + + // Count connector layers by probing to_q presence. + int num_layers = 0; + while (tensor_storage_map.find(connector_pre + ".transformer_1d_blocks." 
+ + std::to_string(num_layers) + ".attn1.to_q.weight") != + tensor_storage_map.end()) { + num_layers++; + } + + // num_registers from learnable_registers.ne (ne[0]=inner_dim, ne[1]=num_registers). + int num_registers = 0; + auto reg_it = tensor_storage_map.find(connector_pre + ".learnable_registers"); + if (reg_it != tensor_storage_map.end() && reg_it->second.n_dims >= 2) { + num_registers = static_cast(reg_it->second.ne[1]); + } + + // Detect gated attention inside the connector (V2 / 22B has this). + bool apply_gated = tensor_storage_map.find( + connector_pre + ".transformer_1d_blocks.0.attn1.to_gate_logits.weight") != + tensor_storage_map.end(); + + // LTX-2 fixes head_dim=128 across both variants. + int head_dim = 128; + int num_heads = static_cast(inner_dim / head_dim); + + // We do NOT include caption_projection here — V1 has it on the DiT side, + // V2 has none. Pass source_dim=Gemma hidden so V2's sqrt(target/source) + // rescale is applied correctly. + connector_runner = std::make_shared( + backend, offload_params_to_cpu, + flat_dim, num_heads, head_dim, num_layers, num_registers, + /*caption_channels=*/0, /*caption_hidden=*/0, /*caption_out=*/0, + /*theta=*/10000.0f, /*max_pos=*/std::vector{1}, + tensor_storage_map, + /*include_caption_projection=*/false, + feat_ext_pre, connector_pre, /*caption_proj_prefix=*/"", + fe_version, /*source_dim=*/gemma_hidden_size, apply_gated); + LOG_INFO("LTX2GemmaConditioner: wired %s connector (flat_dim=%lld inner_dim=%lld " + "num_layers=%d num_registers=%d gated=%d)", + fe_version == LTXConnector::FeatureExtractorVersion::V2 ? "V2" : "V1", + (long long)flat_dim, (long long)inner_dim, num_layers, num_registers, + apply_gated ? 
1 : 0); + } + + void get_param_tensors(std::map& tensors) override { + llm->get_param_tensors(tensors, prefix); + if (connector_runner) { + connector_runner->get_param_tensors(tensors); + } + } + void get_llm_param_tensors(std::map& tensors) override { + if (llm) llm->get_param_tensors(tensors, prefix); + } + void get_non_llm_param_tensors(std::map& tensors) override { + if (connector_runner) connector_runner->get_param_tensors(tensors); + } + void alloc_params_buffer() override { + llm->alloc_params_buffer(); + if (connector_runner) connector_runner->alloc_params_buffer(); + } + void free_params_buffer() override { + llm->free_params_buffer(); + if (connector_runner) connector_runner->free_params_buffer(); + } + size_t get_params_buffer_size() override { + size_t s = llm->get_params_buffer_size(); + if (connector_runner) s += connector_runner->get_params_buffer_size(); + return s; + } + void set_flash_attention_enabled(bool enabled) override { + llm->set_flash_attention_enabled(enabled); + if (connector_runner) connector_runner->set_flash_attention_enabled(enabled); + } + + void set_llm_lazy_load(std::function fn) override { + if (llm) llm->set_lazy_load(std::move(fn)); + } + + SDCondition get_learned_condition(int n_threads, + const ConditionerParams& p) override { + if (!tokenizer) { + LOG_ERROR("LTX2GemmaConditioner: no tokenizer loaded. Construct the conditioner " + "with a path to Gemma's tokenizer.json."); + GGML_ABORT("Gemma tokenizer missing"); + } + // HuggingFace Gemma tokenizer always prepends ; we replicate that here + // so the encoder sees the same sequence the Python reference does. 
+ std::vector real_ids = tokenizer->tokenize(p.text, nullptr, /*padding=*/false); + real_ids.insert(real_ids.begin(), tokenizer->BOS_TOKEN_ID); + const int64_t T_real = static_cast(real_ids.size()); + LOG_DEBUG("LTX2GemmaConditioner: tokenized prompt '%s' -> %lld real tokens", + p.text.c_str(), (long long)T_real); + sd::Tensor empty_mask; + + if (!connector_runner) { + // No connector weights: behave like before Phase 9 landed (no padding). + sd::Tensor ids_tensor({T_real, 1}); + for (int64_t i = 0; i < T_real; ++i) ids_tensor.data()[i] = real_ids[i]; + auto last_hidden = llm->compute(n_threads, ids_tensor, empty_mask, {}, {}); + SDCondition cond; + cond.c_crossattn = last_hidden; + return cond; + } + + // Python LTX-2 tokenizer pads to max_length=1024 with padding_side="left" + // and pad_token = EOS: + // ltx_core/text_encoders/gemma/tokenizer.py:21-24 (padding_side="left", + // pad_token=EOS) and ltx_core/text_encoders/gemma/encoders/base_encoder.py:182 + // (`LTXVGemmaTokenizer(tokenizer_root, 1024)`). + // Gemma processes the full max_length, and the connector then sees a + // max_length-long sequence with learnable_registers tiled max_length/num_reg + // times (8× on the 22B V2 path, where num_reg=128). Padding only to + // num_registers produces the wrong Gemma RoPE positions for the real tokens + // and cuts the DiT cross-attention context by the same factor; both regress + // output quality from recognisable subjects to colored-blob textures. + const int num_registers = connector_runner->num_registers; + // The 1024-pad-with-register-tile path (Python ref's + // `Embeddings1DConnector._replace_padded_with_learnable_registers`) produces + // wrong subjects on LTX-2.3 22B (e.g. "fiery sword" → vintage farm scene), + // while feeding T_real real tokens through the connector blocks produces + // correctly prompt-locked output. Until we localize the discrepancy in the + // tile path, default to the no-pad/no-tile path. 
Set LTX2_PAD=1 to opt back + // into the Python-ref tile path (e.g. for parity dumps). + const bool no_pad = std::getenv("LTX2_PAD") == nullptr; + const int64_t max_length = no_pad ? T_real : 1024; + int64_t T_pad = 0; + int64_t T = T_real; + if (T_real < max_length) { + T_pad = max_length - T_real; + T = max_length; + } else if (T_real > max_length) { + // Prompt already exceeds max_length — truncate to match tokenizer + // behaviour (`truncation=True` in LTXVGemmaTokenizer). + LOG_WARN("LTX2GemmaConditioner: prompt tokenised to %lld >= max_length=%lld; truncating.", + (long long)T_real, (long long)max_length); + real_ids.resize(static_cast(max_length)); + T = max_length; + T_pad = 0; + } + sd::Tensor input_ids({T, 1}); + for (int64_t i = 0; i < T_pad; ++i) input_ids.data()[i] = tokenizer->EOS_TOKEN_ID; + const int64_t real_to_write = std::min(T_real, max_length); + for (int64_t i = 0; i < real_to_write; ++i) input_ids.data()[T_pad + i] = real_ids[i]; + // In no_pad mode the connector skips the register tile (target_seq_len = T_real). + // In padded mode (default) we tile to 1024 to match Python's Embeddings1DConnector. + if (no_pad) { + connector_runner->set_target_seq_len(static_cast(T_real)); + } else { + GGML_ASSERT(num_registers == 0 || max_length % num_registers == 0); + connector_runner->set_target_seq_len(static_cast(max_length)); + } + + // 1. Gemma: compute all N+1 hidden states on the padded sequence. Passing + // T_pad as pad_count tells build_graph to mask out positions [0, T_pad) + // as keys for any real query — without this the real tokens at [T_pad, T) + // attend across all 1024-T_real left-padded EOS tokens and lose subject + // information. HF transformers does this implicitly when given an + // attention_mask=[0..0,1..1] alongside left-padded input_ids. + // Layout returned by compute_all_hidden_states: ne [N+1, H, T, B] = + // PyTorch [B, T, H, N+1] (stack of per-layer hidden states). 
+ auto stacked = llm->compute_all_hidden_states(n_threads, input_ids, empty_mask, + /*pad_count=*/static_cast(T_pad)); + const int64_t B = 1; + const int64_t D = gemma_hidden_size; + const int64_t L = gemma_num_hidden_layers + 1; + GGML_ASSERT(stacked.numel() == L * D * T * B); + + if (const char* dump_path = std::getenv("SD_DUMP_COND_STACKED")) { + FILE* f = std::fopen(dump_path, "wb"); + if (f) { + std::fwrite(stacked.data(), sizeof(float), stacked.numel(), f); + std::fclose(f); + LOG_INFO("SD_DUMP_COND_STACKED: wrote %ld floats to %s (ne=[%ld,%ld,%ld,%ld])", + (long)stacked.numel(), dump_path, (long)L, (long)D, (long)T, (long)B); + } + } + + // 2. CPU normalize → [B, T, D*L]. seq_lens=[T_real_eff] + left-padding tells + // the normalizer to zero out the pad positions (which live at [0, T_pad)). + // T_real_eff caps at max_length to handle the truncated-prompt branch above. + const int64_t T_real_eff = std::min(T_real, max_length); + std::vector seq_lens(B, static_cast(T_real_eff)); + sd::Tensor normed({D * L, T, B}); + if (use_v2_feature_extractor) { + LTXConnector::feature_extractor_normalize_v2( + stacked.data(), seq_lens.data(), normed.data(), + static_cast(B), static_cast(T), static_cast(D), static_cast(L), + "left", 1e-6f); + } else { + LTXConnector::feature_extractor_normalize( + stacked.data(), seq_lens.data(), normed.data(), + static_cast(B), static_cast(T), static_cast(D), static_cast(L), + "left", 1e-6f); + } + + // Python's Embeddings1DConnector._replace_padded_with_learnable_registers moves + // the real-token rows from [T_pad, T) to the START of the sequence and replaces + // the now-empty tail with learnable_registers[T_real:]. Equivalent CPU-side shift: + // after normalize, [0,T_pad) holds zeros (masked pad), [T_pad,T) holds real. + // Slide the real rows down to [0,T_real) and re-zero the tail — the connector + // runner then tiles/slices learnable_registers[T_real:max_length] and concats. 
+ if (T_pad > 0) { + const int64_t flat_dim = D * L; + sd::Tensor reals({flat_dim, T_real_eff, B}); + for (int64_t b = 0; b < B; ++b) { + std::memcpy(reals.data() + b * T_real_eff * flat_dim, + normed.data() + b * T * flat_dim + T_pad * flat_dim, + static_cast(T_real_eff * flat_dim) * sizeof(float)); + } + normed = std::move(reals); + } + + // 3. Run connector. Stage selection: + // 0 = stop right after the feature_extractor projection (no register + // tiling, no transformer_1d_blocks, no final rms_norm). This is + // what PR #1459 / #1463 do — they treat the embeddings_connector + // as a refinement layer and skip it. + // 3 = full Embeddings1DConnector (Python-faithful: register-tile to + // max_length, run all blocks, final rms_norm). + // Defaulting to stage=3 (faithful). Override with LTX2_COND_STAGE=N for + // diagnostics. + int cond_stage = 3; + if (const char* cs = std::getenv("LTX2_COND_STAGE")) { + cond_stage = std::atoi(cs); + } + auto context = connector_runner->compute(n_threads, normed, cond_stage); + + if (const char* dump_path = std::getenv("SD_DUMP_COND_CONTEXT")) { + FILE* f = std::fopen(dump_path, "wb"); + if (f) { + std::fwrite(context.data(), sizeof(float), context.numel(), f); + std::fclose(f); + LOG_INFO("SD_DUMP_COND_CONTEXT: wrote %ld floats to %s", + (long)context.numel(), dump_path); + } + } + + SDCondition cond; + cond.c_crossattn = context; + return cond; + } +}; + #endif diff --git a/src/denoiser.hpp b/src/denoiser.hpp index a6e81d597..4613bffd2 100644 --- a/src/denoiser.hpp +++ b/src/denoiser.hpp @@ -720,6 +720,97 @@ struct FluxFlowDenoiser : public DiscreteFlowDenoiser { } }; +// LTX-2 flow-match denoiser. 
+// +// Reference: /devel/tools/diffusion/LTX-2/packages/ltx-core/src/ltx_core/components/schedulers.py +// +// Key differences from FluxFlowDenoiser: +// - sigma_to_t(σ) = σ * 1000 (Flux passes raw σ; LTX's TransformerArgsPreprocessor scales by +// 1000 in Python, but we externalise that to the denoiser so the +// DiT's AdaLayerNormSingle doesn't double-multiply). +// - Token-count-dependent shift: mu = linear_interp(tokens, 1024→0.95, 4096→2.05), log-space. +// - Terminal stretch: after flux_time_shift, rescale non-zero sigmas so the last non-zero lands +// at `terminal` (default 0.1). This is what the LTX-2 distilled LoRAs expect. +// - scheduler_t is ignored — LTX2Scheduler is fixed; a non-default value would give wrong +// behaviour for the trained weights. +struct LTX2FlowDenoiser : public DiscreteFlowDenoiser { + static constexpr int BASE_SHIFT_ANCHOR = 1024; + static constexpr int MAX_SHIFT_ANCHOR = 4096; + + // Log-space shift anchors; get exponentiated in compute_mu. + float max_shift = 2.05f; + float base_shift = 0.95f; + float terminal = 0.1f; + bool stretch = true; + + LTX2FlowDenoiser() = default; + + // Compute the shift `mu` used inside flux_time_shift. Python: + // mm = (max_shift - base_shift) / (MAX_ANCHOR - BASE_ANCHOR) + // b = base_shift - mm * BASE_ANCHOR + // sigma_shift = tokens * mm + b + float compute_mu(int tokens) const { + float mm = (max_shift - base_shift) / static_cast(MAX_SHIFT_ANCHOR - BASE_SHIFT_ANCHOR); + float b = base_shift - mm * static_cast(BASE_SHIFT_ANCHOR); + return static_cast(tokens) * mm + b; + } + + // t_to_sigma uses the base-shift mapping as a best-effort inverse. The real inverse depends on + // the terminal stretch, which needs the full schedule context — sampling never actually inverts + // t_to_sigma at arbitrary points, so this is only here to satisfy the virtual interface. 
+ float t_to_sigma(float t) override { + return flux_time_shift(base_shift, 1.0f, (t + 1.0f) / TIMESTEPS); + } + + std::vector get_sigmas(uint32_t n, int image_seq_len, scheduler_t scheduler_type, SDVersion /*version*/) override { + if (scheduler_type != DISCRETE_SCHEDULER) { + LOG_WARN("LTX2FlowDenoiser: ignoring scheduler_type=%d; LTX-2 uses a fixed schedule", + static_cast(scheduler_type)); + } + + int tokens = image_seq_len > 0 ? image_seq_len : MAX_SHIFT_ANCHOR; + float mu = compute_mu(tokens); + float exp_mu = std::exp(mu); + LOG_DEBUG("LTX2FlowDenoiser: tokens=%d mu=%.4f stretch=%d terminal=%.3f", + tokens, mu, stretch ? 1 : 0, terminal); + + std::vector sigmas(n + 1); + // linspace(1.0, 0.0, n+1) then apply flux_time_shift (power=1) to non-zero entries. + for (uint32_t i = 0; i <= n; ++i) { + float t = 1.0f - static_cast(i) / static_cast(n); + if (t <= 0.0f) { + sigmas[i] = 0.0f; + } else { + sigmas[i] = exp_mu / (exp_mu + (1.0f / t - 1.0f)); + } + } + + // Terminal stretch: rescale `1 - σ` so that the last non-zero σ lands at `terminal`. 
+ if (stretch) { + int last_nonzero = -1; + for (int i = static_cast(n); i >= 0; --i) { + if (sigmas[i] > 0.0f) { + last_nonzero = i; + break; + } + } + if (last_nonzero > 0) { + float one_minus_last = 1.0f - sigmas[last_nonzero]; + float scale_factor = one_minus_last / (1.0f - terminal); + if (scale_factor > 0.0f) { + for (uint32_t i = 0; i <= n; ++i) { + if (sigmas[i] > 0.0f) { + sigmas[i] = 1.0f - (1.0f - sigmas[i]) / scale_factor; + } + } + } + } + } + + return sigmas; + } +}; + struct Flux2FlowDenoiser : public FluxFlowDenoiser { Flux2FlowDenoiser() = default; diff --git a/src/diffusion_model.hpp b/src/diffusion_model.hpp index c0a2a11c0..8c16c0e77 100644 --- a/src/diffusion_model.hpp +++ b/src/diffusion_model.hpp @@ -5,6 +5,7 @@ #include "anima.hpp" #include "ernie_image.hpp" #include "flux.hpp" +#include "ltx.hpp" #include "mmdit.hpp" #include "qwen_image.hpp" #include "tensor_ggml.hpp" @@ -29,6 +30,10 @@ struct DiffusionParams { const sd::Tensor* vace_context = nullptr; float vace_strength = 1.f; const std::vector* skip_layers = nullptr; + // STG (Spatio-Temporal Guidance): block indices whose self-attention is bypassed + // during the perturbed-pass forward. Currently only LTX-2 honors this; other + // models ignore the field. Empty/nullptr means no perturbation. + const std::vector* stg_skip_blocks = nullptr; }; template @@ -50,6 +55,14 @@ struct DiffusionModel { virtual int64_t get_adm_in_channels() = 0; virtual void set_flash_attention_enabled(bool enabled) = 0; virtual void set_circular_axes(bool circular_x, bool circular_y) = 0; + // Overridden only by models whose spatial / temporal embeddings depend on the + // output fps (currently LTX-2). Image-only models ignore the value. + virtual void set_fps(float fps) {} + // Lazy-load hook — register a callback on the inner GGMLRunner. 
Default + // no-op for models that don't yet support sequential lazy loading; the + // user-facing OOM scenarios ( Q6_K LTX-2 + Q8_K_XL Gemma on 24GB combined + // VRAM) only need this for LTXDiffusionModel right now. + virtual void set_lazy_load(std::function /*fn*/) {} }; struct UNetModel : public DiffusionModel { @@ -517,6 +530,82 @@ struct ZImageModel : public DiffusionModel { } }; +struct LTXDiffusionModel : public DiffusionModel { + std::string prefix; + LTX::LTXRunner ltx; + + LTXDiffusionModel(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "model.diffusion_model", + SDVersion version = VERSION_LTX2) + : prefix(prefix), ltx(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) { + } + + std::string get_desc() override { + return ltx.get_desc(); + } + + void alloc_params_buffer() override { + ltx.alloc_params_buffer(); + } + + void free_params_buffer() override { + ltx.free_params_buffer(); + } + + void free_compute_buffer() override { + ltx.free_compute_buffer(); + } + + void get_param_tensors(std::map& tensors) override { + ltx.get_param_tensors(tensors, prefix); + } + + size_t get_params_buffer_size() override { + return ltx.get_params_buffer_size(); + } + + void set_lazy_load(std::function fn) override { + ltx.set_lazy_load(std::move(fn)); + } + + void set_weight_adapter(const std::shared_ptr& adapter) override { + ltx.set_weight_adapter(adapter); + } + + int64_t get_adm_in_channels() override { + return 0; + } + + void set_flash_attention_enabled(bool enabled) override { + ltx.set_flash_attention_enabled(enabled); + } + + void set_circular_axes(bool circular_x, bool circular_y) override { + ltx.set_circular_axes(circular_x, circular_y); + } + + void set_fps(float fps) override { + if (fps > 0.f) { + ltx.set_fps(fps); + } + } + + sd::Tensor compute(int n_threads, + const DiffusionParams& diffusion_params) override { + GGML_ASSERT(diffusion_params.x != 
nullptr); + GGML_ASSERT(diffusion_params.timesteps != nullptr); + static const sd::Tensor empty; + return ltx.compute(n_threads, + *diffusion_params.x, + *diffusion_params.timesteps, + tensor_or_empty(diffusion_params.context), + empty, + diffusion_params.stg_skip_blocks); + } +}; + struct ErnieImageModel : public DiffusionModel { std::string prefix; ErnieImage::ErnieImageRunner ernie_image; diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index 859270cbd..660f3bc94 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -1307,6 +1307,29 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor* return out; } +// Optional tap collector for debugging the attention block. The Gemma parity +// test sets this to capture intermediate kq/softmax/kqv values. Inplace ops +// in the manual attention path overwrite the tap tensors' names to "node_X" +// or "(view)" so the names don't survive collection — left here as a knob +// for future experiments where the manual path is rewritten to use +// non-inplace variants. 
+inline std::vector<ggml_tensor*>* g_attn_layer0_taps = nullptr; +inline int g_attn_tap_count = 0; +inline int g_attn_tap_budget = 7; +inline ggml_tensor* attn_tap(ggml_context* ctx, ggml_tensor* t, const char* name) { + if (g_attn_layer0_taps == nullptr) return t; + if (g_attn_tap_count >= g_attn_tap_budget) return t; + ggml_tensor* keep = ggml_cont(ctx, t); + ggml_set_output(keep); + std::string full = std::string("DBG:") + name; + ggml_set_name(keep, full.c_str()); + fprintf(stderr, "[attn_tap] pushed name='%s' (intended='%s') tensor=%p\n", + ggml_get_name(keep), full.c_str(), (void*)keep); + g_attn_layer0_taps->push_back(keep); + ++g_attn_tap_count; + return keep; +} + // q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head] // k: [N, L_k, n_kv_head*d_head] or [N*n_kv_head, L_k, d_head] // v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head] @@ -1439,16 +1462,27 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx, // } v = ggml_ext_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_kv_head, d_head, L_k] v = ggml_reshape_3d(ctx, v, L_k, d_head, n_kv_head * N); // [N * n_kv_head, d_head, L_k] + v = attn_tap(ctx, v, "_attn_v"); auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k] ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - kq = ggml_scale_inplace(ctx, kq, scale); + kq = attn_tap(ctx, kq, "_attn_kq_raw"); + // Non-inplace variants here so the parity-test taps above survive + // collection. ggml's allocator reuses inplace-op buffers and overwrites + // the cont tensor's name, leaving us with "node_42 / (view)" entries + // that the dumper can't match. 
+ kq = ggml_scale(ctx, kq, scale); + kq = attn_tap(ctx, kq, "_attn_kq_scaled"); if (mask) { - kq = ggml_add_inplace(ctx, kq, mask); + kq = ggml_add(ctx, kq, mask); + kq = attn_tap(ctx, kq, "_attn_kq_masked"); } - kq = ggml_soft_max_inplace(ctx, kq); + kq = ggml_soft_max(ctx, kq); + kq = attn_tap(ctx, kq, "_attn_softmax"); kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head] + ggml_mul_mat_set_prec(kqv, GGML_PREC_F32); + kqv = attn_tap(ctx, kqv, "_attn_kqv_raw"); kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head] kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, n_head, d_head] @@ -1684,6 +1718,59 @@ struct GGMLRunnerContext { std::shared_ptr weight_adapter = nullptr; }; +// Forward declaration — defined near support_get_rows() below. Used by +// GGMLRunner's ctor to publish its params backend so Embedding::init_params +// can pick the right get_rows allowlist without plumbing backend through +// every init_params override. +__STATIC_INLINE__ ggml_backend_t& current_params_backend(); + +// --------------------------------------------------------------------------- +// Multi-GPU tensor split support +// --------------------------------------------------------------------------- +// A GGMLRunner can opt into "tensor split" mode where matmul weight tensors +// are distributed row-wise across multiple GPUs (CUDA-only today). This is +// triggered by setting `g_pending_multi_backend_spec()` to a populated +// MultiBackendSpec just before constructing the runner. The runner ctor +// consumes and clears the pointer. 
+// +// In multi-backend mode the runner: +// - allocates a CUDA split-buffer-type for matmul-eligible weights +// - allocates a regular per-device buffer for non-matmul weights (norms, +// biases, embeddings, small projections, non-contiguous weights) +// - replaces ggml_gallocr with ggml_backend_sched_t for graph compute +// - the CUDA backend internally dispatches split-tensor matmul work to +// all participating devices (via tensor->extra->data_device[id]) +// +// Limitations: +// - tensor split is incompatible with offload_params_to_cpu (offload +// assumes a single contiguous params buffer) +// - non-CUDA builds silently disable the request and fall back to single- +// backend mode +struct MultiBackendSpec { + // Extra GPU backends *in addition to* the runner's main runtime_backend. + // Together {runtime_backend} ∪ extra_backends forms the GPU set. + std::vector extra_backends; + + // Per-device row-split ratios (length = total CUDA device count). + // If empty, the CUDA backend's default tensor_split (free-VRAM + // proportions) is used. + std::vector tensor_split; + + // CUDA device id of the "main" GPU (the one the split buft is anchored + // to). Typically 0. + int main_device = 0; + + // CPU backend appended last to the sched for unsupported-op fallback. + // Optional — may be nullptr to skip CPU fallback. + ggml_backend_t cpu_fallback = nullptr; +}; + +// Thread-local pending spec consumed by the next GGMLRunner ctor. +__STATIC_INLINE__ MultiBackendSpec*& g_pending_multi_backend_spec() { + thread_local MultiBackendSpec* spec = nullptr; + return spec; +} + struct GGMLRunner { protected: typedef std::function get_graph_cb_t; @@ -1703,6 +1790,29 @@ struct GGMLRunner { ggml_context* compute_ctx = nullptr; ggml_gallocr* compute_allocr = nullptr; + // --- multi-backend (tensor split) state --- + // multi_backend_mode toggles the multi-backend code paths in + // alloc_params_buffer / compute(). 
False (default) keeps the existing + // single-backend gallocr path with zero behavioural change. + bool multi_backend_mode = false; + std::vector extra_backends; // additional GPU backends + std::vector tensor_split_ratios; + int main_device = 0; + ggml_backend_t cpu_fallback_backend = nullptr; + ggml_backend_buffer_type_t split_buft = nullptr; + ggml_backend_buffer_t split_params_buffer = nullptr; + ggml_backend_sched_t sched = nullptr; + bool sched_reserved = false; + + // --- Lazy load (sequential per-component loading) --- + // When set, the runner defers alloc_params_buffer + reading bytes from disk + // until the first compute() call. After free_params_buffer, the next + // compute() re-allocates and re-loads. Lets multi-component pipelines whose + // total params exceed combined VRAM run by loading at most one component + // at a time. Eager mode (default, lazy_load_fn unset) preserves existing + // behaviour: caller calls alloc_params_buffer() + populates tensors. + std::function lazy_load_fn = nullptr; + std::shared_ptr weight_adapter = nullptr; std::vector one_vec = {1.f}; @@ -1838,7 +1948,52 @@ struct GGMLRunner { return gf; } + // Build the sched on first use. Sched lifetime ties to the runner; reset + // happens on every compute() to clear prev-graph allocations. + bool ensure_sched() { + if (sched != nullptr) return true; + + std::vector backends; + backends.reserve(1 + extra_backends.size() + (cpu_fallback_backend ? 
1 : 0)); + backends.push_back(runtime_backend); + for (auto* b : extra_backends) backends.push_back(b); + if (cpu_fallback_backend != nullptr) backends.push_back(cpu_fallback_backend); + + sched = ggml_backend_sched_new(backends.data(), + /*bufts=*/nullptr, + (int)backends.size(), + MAX_GRAPH_SIZE, + /*parallel=*/false, + /*op_offload=*/false); + if (sched == nullptr) { + LOG_ERROR("%s: failed to create backend sched", get_desc().c_str()); + return false; + } + return true; + } + bool alloc_compute_buffer(get_graph_cb_t get_graph) { + if (multi_backend_mode) { + if (sched_reserved) return true; + if (!ensure_sched()) return false; + reset_compute_ctx(); + ggml_cgraph* gf = get_compute_graph(get_graph); + backend_tensor_data_map.clear(); + if (!ggml_backend_sched_reserve(sched, gf)) { + LOG_ERROR("%s: sched reserve failed", get_desc().c_str()); + return false; + } + sched_reserved = true; + for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); i++) { + ggml_backend_t b = ggml_backend_sched_get_backend(sched, i); + size_t s = ggml_backend_sched_get_buffer_size(sched, b); + LOG_DEBUG("%s sched buf[%d] %s = %.2f MB", + get_desc().c_str(), i, ggml_backend_name(b), + s / (1024.f * 1024.f)); + } + return true; + } + if (compute_allocr != nullptr) { return true; } @@ -1997,12 +2152,53 @@ struct GGMLRunner { GGMLRunner(ggml_backend_t backend, bool offload_params_to_cpu = false) : runtime_backend(backend) { + // Consume any pending multi-backend (tensor-split) spec set by the + // caller via g_pending_multi_backend_spec(). + MultiBackendSpec* pending = g_pending_multi_backend_spec(); + if (pending != nullptr) { + g_pending_multi_backend_spec() = nullptr; // consume +#ifdef SD_USE_CUDA + multi_backend_mode = true; + extra_backends = pending->extra_backends; + tensor_split_ratios = pending->tensor_split; + main_device = pending->main_device; + cpu_fallback_backend = pending->cpu_fallback; + // Build the CUDA split buft. 
tensor_split_ratios may be empty, + // in which case the CUDA backend uses its default free-VRAM + // proportional split. + split_buft = ggml_backend_cuda_split_buffer_type( + main_device, + tensor_split_ratios.empty() ? nullptr : tensor_split_ratios.data()); + if (split_buft == nullptr) { + LOG_WARN("multi-backend: split buft init failed; falling back to single-backend mode"); + multi_backend_mode = false; + extra_backends.clear(); + cpu_fallback_backend = nullptr; + } + if (multi_backend_mode && offload_params_to_cpu) { + LOG_WARN("multi-backend: tensor split is incompatible with offload_params_to_cpu; " + "ignoring offload"); + offload_params_to_cpu = false; + } +#else + (void)pending; + LOG_WARN("multi-backend: tensor split requested but build lacks CUDA; ignoring"); +#endif + } + alloc_params_ctx(); if (!ggml_backend_is_cpu(runtime_backend) && offload_params_to_cpu) { params_backend = ggml_backend_cpu_init(); } else { params_backend = runtime_backend; } + // Publish the RUNTIME backend (not params) so block init_params() can + // reach it from support_get_rows() without plumbing backend through + // every init_params override. Runtime matters for get_rows: when + // offload_params_to_cpu is true, params live on CPU but the actual + // get_rows executes on the runtime backend (the GPU), which requires + // the weight dtype to be CUDA-supported. + current_params_backend() = runtime_backend; } virtual ~GGMLRunner() { @@ -2014,6 +2210,14 @@ struct GGMLRunner { ggml_backend_free(params_backend); } free_cache_ctx_and_buffer(); + // The extra GPU backends and the CPU fallback are owned by the caller + // (see init_multi_backend_spec in stable-diffusion.cpp). The split + // buft is a process-cached singleton owned by the CUDA backend, so + // we don't free it here. Sched is per-runner and owned here. 
+ if (sched != nullptr) { + ggml_backend_sched_free(sched); + sched = nullptr; + } } virtual GGMLRunnerContext get_context() { @@ -2033,13 +2237,108 @@ struct GGMLRunner { alloc_compute_ctx(); } - bool alloc_params_buffer() { + // Heuristic: which weight tensors should be row-split across GPUs? + // Constraints from the CUDA split buft impl: + // - must be contiguous (init_tensor asserts on this) + // - row-splitting only pays off for sizable matmul weights, so we + // restrict to rank-2 tensors with both dims >= 256 + // 1D biases, norms, embeddings, and small projections fall back to the + // main GPU's regular per-device buft. + static bool is_split_eligible(const ggml_tensor* t) { + if (!ggml_is_contiguous(t)) return false; + if (ggml_n_dims(t) != 2) return false; + if (t->ne[0] < 256 || t->ne[1] < 256) return false; + return true; + } + + bool alloc_params_buffer_multi() { +#ifdef SD_USE_CUDA + // Walk params_ctx, classify tensors into split-eligible vs. main, then + // allocate two buffers (split buft + per-device buft) and bind tensors + // via ggml_tallocr. 
+ ggml_backend_buffer_type_t main_buft = ggml_backend_get_default_buffer_type(runtime_backend); + const size_t main_align = ggml_backend_buft_get_alignment(main_buft); + const size_t split_align = ggml_backend_buft_get_alignment(split_buft); + + size_t main_size = 0, split_size = 0; + size_t main_count = 0, split_count = 0; + for (ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != nullptr; + t = ggml_get_next_tensor(params_ctx, t)) { + if (is_split_eligible(t)) { + size_t s = ggml_backend_buft_get_alloc_size(split_buft, t); + split_size += GGML_PAD(s, split_align); + split_count++; + } else { + size_t s = ggml_backend_buft_get_alloc_size(main_buft, t); + main_size += GGML_PAD(s, main_align); + main_count++; + } + } + + if (main_size > 0) { + params_buffer = ggml_backend_buft_alloc_buffer(main_buft, main_size); + if (params_buffer == nullptr) { + LOG_ERROR("%s alloc main params buffer failed (%.1f MB)", + get_desc().c_str(), main_size / (1024.f * 1024.f)); + return false; + } + } + if (split_size > 0) { + split_params_buffer = ggml_backend_buft_alloc_buffer(split_buft, split_size); + if (split_params_buffer == nullptr) { + LOG_ERROR("%s alloc split params buffer failed (%.1f MB)", + get_desc().c_str(), split_size / (1024.f * 1024.f)); + return false; + } + } + + ggml_tallocr main_alloc{}; + ggml_tallocr split_alloc{}; + if (params_buffer != nullptr) main_alloc = ggml_tallocr_new(params_buffer); + if (split_params_buffer != nullptr) split_alloc = ggml_tallocr_new(split_params_buffer); + + for (ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != nullptr; + t = ggml_get_next_tensor(params_ctx, t)) { + ggml_status st = is_split_eligible(t) + ? 
ggml_tallocr_alloc(&split_alloc, t) + : ggml_tallocr_alloc(&main_alloc, t); + if (st != GGML_STATUS_SUCCESS) { + LOG_ERROR("%s tallocr_alloc failed for tensor %s", + get_desc().c_str(), t->name); + return false; + } + } + + if (params_buffer != nullptr) { + ggml_backend_buffer_set_usage(params_buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } + if (split_params_buffer != nullptr) { + ggml_backend_buffer_set_usage(split_params_buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } + + LOG_INFO("%s tensor-split params: main %.1f MB (%zu tensors), split %.1f MB (%zu tensors)", + get_desc().c_str(), + main_size / (1024.f * 1024.f), main_count, + split_size / (1024.f * 1024.f), split_count); + return true; +#else + LOG_ERROR("alloc_params_buffer_multi called without CUDA support"); + return false; +#endif + } + + // Internal allocator — always materializes the params buffer. Used by + // both the eager `alloc_params_buffer` path and the lazy + // `ensure_params_loaded` path; the latter must bypass the lazy-skip. + bool do_alloc_params_buffer() { + if (multi_backend_mode && split_buft != nullptr) { + return alloc_params_buffer_multi(); + } size_t num_tensors = ggml_tensor_num(params_ctx); params_buffer = ggml_backend_alloc_ctx_tensors(params_ctx, params_backend); if (params_buffer == nullptr) { LOG_ERROR("%s alloc params backend buffer failed, num_tensors = %i", - get_desc().c_str(), - num_tensors); + get_desc().c_str(), num_tensors); return false; } size_t params_buffer_size = ggml_backend_buffer_get_size(params_buffer); @@ -2048,21 +2347,111 @@ struct GGMLRunner { params_buffer_size / (1024.f * 1024.f), ggml_backend_is_cpu(params_backend) ? 
"RAM" : "VRAM", num_tensors); + if (params_buffer_size >= size_t(1) << 30) { + std::map> per_type; + for (ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != nullptr; + t = ggml_get_next_tensor(params_ctx, t)) { + auto& entry = per_type[t->type]; + entry.first += ggml_nbytes(t); + entry.second += 1; + } + std::string breakdown; + for (const auto& kv : per_type) { + char buf[96]; + std::snprintf(buf, sizeof(buf), "%s %zu/%6.1fMB ", + ggml_type_name(kv.first), kv.second.second, + kv.second.first / (1024.f * 1024.f)); + breakdown += buf; + } + LOG_INFO("%s param breakdown: %s", get_desc().c_str(), breakdown.c_str()); + } return true; } + bool alloc_params_buffer() { + // Lazy mode: defer alloc + disk read until first compute(). The runner + // is configured but holds no VRAM until then. Caller still goes through + // the same alloc + get_param_tensors flow at init; the load step inside + // ModelLoader::load_tensors silently skips this runner's tensors (their + // ->data is null because no buffer is allocated yet — the loader's + // unknown-tensor branch handles that via the per-component tensor map). + if (lazy_load_fn) return true; + return do_alloc_params_buffer(); + } + void free_params_buffer() { + // When offload_params_to_cpu is in effect, the tensors currently point + // at `runtime_params_buffer` (on the runtime backend). Restore them to + // `params_buffer` and free the runtime copy before freeing the params + // buffer itself — otherwise subsequent offloads would re-copy freed + // memory and the runtime buffer would leak on teardown. + offload_params_to_params_backend(); if (params_buffer != nullptr) { ggml_backend_buffer_free(params_buffer); params_buffer = nullptr; } + if (split_params_buffer != nullptr) { + ggml_backend_buffer_free(split_params_buffer); + split_params_buffer = nullptr; + } + // Drop sched too — its compute buffer is bound to the now-freed params + // backend(s). 
Next compute() will re-create it from a freshly allocated + // (lazy-loaded) params buffer. + if (sched != nullptr) { + ggml_backend_sched_free(sched); + sched = nullptr; + sched_reserved = false; + } + } + + // Register a callback that lazy-loads this runner's weights into ggml + // tensors AFTER alloc_params_buffer succeeds. ensure_params_loaded() + // invokes alloc + this callback on the first compute() (or after a + // previous free_params_buffer). Used by the sequential-loading path in + // stable-diffusion.cpp where DiT and Conditioner lazy-load their weights + // from disk so peak VRAM is per-phase rather than sum-of-components. + void set_lazy_load(std::function fn) { + lazy_load_fn = std::move(fn); + } + + bool ensure_params_loaded() { + // Already alloc'd (eager mode or previously lazy-loaded and not yet freed). + if (params_buffer != nullptr || split_params_buffer != nullptr) { + return true; + } + if (!lazy_load_fn) { + // Eager mode but caller forgot to alloc — surface an error rather + // than crashing later in graph alloc. + LOG_ERROR("%s: no params buffer and no lazy_load_fn — caller must alloc_params_buffer", get_desc().c_str()); + return false; + } + int64_t t0 = ggml_time_ms(); + // Use do_alloc_params_buffer (which bypasses the lazy-skip in the + // public alloc_params_buffer wrapper). Otherwise the buffer would + // remain unallocated and lazy_load_fn would write into null tensor + // data pointers and the disk read would fail. 
+ if (!do_alloc_params_buffer()) { + LOG_ERROR("%s: lazy alloc_params_buffer failed", get_desc().c_str()); + return false; + } + if (!lazy_load_fn()) { + LOG_ERROR("%s: lazy load callback failed", get_desc().c_str()); + return false; + } + int64_t t1 = ggml_time_ms(); + LOG_INFO("%s: lazy-loaded params in %.2fs", get_desc().c_str(), (t1 - t0) / 1000.f); + return true; } size_t get_params_buffer_size() { + size_t total = 0; if (params_buffer != nullptr) { - return ggml_backend_buffer_get_size(params_buffer); + total += ggml_backend_buffer_get_size(params_buffer); + } + if (split_params_buffer != nullptr) { + total += ggml_backend_buffer_get_size(split_params_buffer); } - return 0; + return total; } void free_cache_ctx_and_buffer() { @@ -2075,11 +2464,32 @@ struct GGMLRunner { ggml_gallocr_free(compute_allocr); compute_allocr = nullptr; } - offload_params_to_params_backend(); + if (sched != nullptr) { + // Reset rather than free: keeping the sched alive across compute() + // calls of a sampling loop avoids the per-step rebuild cost. + // free_params_immediately tears the runner down entirely (dtor), + // which frees sched fully — see ~GGMLRunner. + ggml_backend_sched_reset(sched); + sched_reserved = false; + } + // Intentionally do NOT call offload_params_to_params_backend() here. + // For offload mode, keeping runtime_params_buffer resident across + // compute() calls of the same runner is the whole point — otherwise + // a DiT sampling loop re-uploads ~9 GB from RAM to GPU every CFG pass + // (observed ~5.7 min of pure upload overhead on 60-step 720p runs). + // free_params_buffer() handles the offload teardown when the caller + // is actually done with the runner (sd.cpp triggers it via + // free_params_immediately between components). } // do copy after alloc graph void set_backend_tensor_data(ggml_tensor* tensor, const void* data) { + // Mark as INPUT so ggml_backend_sched assigns the tensor a backend + // (last-prio CPU). 
Without this, tensors with no producers and no + // GPU consumers leave sched at backend_id=-1, tripping + // ggml_gallocr_allocate_node's GGML_ASSERT(buffer_id >= 0). Harmless + // for the gallocr (single-backend) path. + ggml_set_input(tensor); backend_tensor_data_map[tensor] = data; } @@ -2139,6 +2549,11 @@ struct GGMLRunner { int n_threads, bool free_compute_buffer_immediately, bool no_return = false) { + // Lazy load: if this runner was registered with a lazy_load_fn, + // alloc_params_buffer + read weights from disk now (or after a free). + if (!ensure_params_loaded()) { + return std::nullopt; + } if (!offload_params_to_runtime_backend()) { LOG_ERROR("%s offload params to runtime backend failed", get_desc().c_str()); return std::nullopt; @@ -2147,6 +2562,41 @@ struct GGMLRunner { LOG_ERROR("%s alloc compute buffer failed", get_desc().c_str()); return std::nullopt; } + + if (multi_backend_mode) { + // Sched path: reset, build a fresh graph, alloc through the sched, + // populate inputs, compute. The CUDA backend handles cross-device + // dispatch for split-tensor matmuls internally. 
+ ggml_backend_sched_reset(sched); + reset_compute_ctx(); + ggml_cgraph* gf = get_compute_graph(get_graph); + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LOG_ERROR("%s sched alloc graph failed", get_desc().c_str()); + return std::nullopt; + } + copy_data_to_backend_tensor(); + if (cpu_fallback_backend && ggml_backend_is_cpu(cpu_fallback_backend)) { + ggml_backend_cpu_set_n_threads(cpu_fallback_backend, n_threads); + } + ggml_status status = ggml_backend_sched_graph_compute(sched, gf); + if (status != GGML_STATUS_SUCCESS) { + LOG_ERROR("%s sched compute failed: %s", + get_desc().c_str(), ggml_status_to_string(status)); + return std::nullopt; + } + ggml_backend_sched_synchronize(sched); + copy_cache_tensors_to_cache_buffer(); + auto result = ggml_get_tensor(compute_ctx, final_result_name.c_str()); + std::optional> output; + if (!no_return) { + output = sd::make_sd_tensor_from_ggml(result); + } + if (free_compute_buffer_immediately) { + free_compute_buffer(); + } + return output; + } + reset_compute_ctx(); ggml_cgraph* gf = get_compute_graph(get_graph); if (!ggml_gallocr_alloc_graph(compute_allocr, gf)) { @@ -2353,12 +2803,67 @@ class Linear : public UnaryBlock { } }; +// Set by GGMLRunner's constructor to the params backend of the most recently +// constructed runner. Read by support_get_rows() below. Defined as a +// function-local static so the header stays single-definition. +__STATIC_INLINE__ ggml_backend_t& current_params_backend() { + static ggml_backend_t b = nullptr; + return b; +} + __STATIC_INLINE__ bool support_get_rows(ggml_type wtype) { - std::set allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0}; - if (allow_types.find(wtype) != allow_types.end()) { - return true; - } - return false; + // ggml-cpu implements get_rows for the full quant set in + // ggml_compute_forward_get_rows (ggml-cpu/ops.cpp) — both the legacy + // Q{4,5,8}_{0,1} formats AND the K-quants / IQ-quants. 
The CUDA kernel + // in ggml-cuda/getrows.cu only supports F16/BF16/F32/I32 + legacy + // Q4_{0,1}/Q5_{0,1}/Q8_0 — calling it with a K- or IQ-quant aborts. + // So the allowlist must match the BACKEND that will actually hold the + // Embedding weight. When the current runner's params backend is CUDA + // we fall back to F32 for non-legacy quants (costs ~3.5 GB VRAM for a + // Gemma 12B IQ4_XS token_embd, but is the only option that runs). + static const std::set allow_legacy = { + GGML_TYPE_F16, + GGML_TYPE_BF16, + GGML_TYPE_Q8_0, + GGML_TYPE_Q5_1, + GGML_TYPE_Q5_0, + GGML_TYPE_Q4_1, + GGML_TYPE_Q4_0, + }; + static const std::set allow_full = { + GGML_TYPE_F16, + GGML_TYPE_BF16, + GGML_TYPE_Q8_0, + GGML_TYPE_Q5_1, + GGML_TYPE_Q5_0, + GGML_TYPE_Q4_1, + GGML_TYPE_Q4_0, + GGML_TYPE_Q2_K, + GGML_TYPE_Q3_K, + GGML_TYPE_Q4_K, + GGML_TYPE_Q5_K, + GGML_TYPE_Q6_K, + GGML_TYPE_IQ2_XXS, + GGML_TYPE_IQ2_XS, + GGML_TYPE_IQ2_S, + GGML_TYPE_IQ3_XXS, + GGML_TYPE_IQ3_S, + GGML_TYPE_IQ1_S, + GGML_TYPE_IQ1_M, + GGML_TYPE_IQ4_NL, + GGML_TYPE_IQ4_XS, + }; + + ggml_backend_t b = current_params_backend(); + const bool on_cpu = + (b == nullptr) || ggml_backend_is_cpu(b); + // Debug knob: set SD_FORCE_LEGACY_GETROWS=1 to apply the CUDA-safe + // allowlist on both CPU and CUDA. Useful for isolating whether CPU and + // CUDA agree on the F32 fallback path. + static const bool force_legacy = + std::getenv("SD_FORCE_LEGACY_GETROWS") != nullptr; + const auto& allow = (on_cpu && !force_legacy) ? allow_full : allow_legacy; + return allow.count(wtype) > 0; } class Embedding : public UnaryBlock { diff --git a/src/llm.hpp b/src/llm.hpp index 4afaa3ba6..d67c2cf9d 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -21,14 +21,40 @@ #include "tokenizers/mistral_tokenizer.h" #include "tokenizers/qwen2_tokenizer.h" +// Debug tap: when non-null, Gemma layer-0 forward paths push intermediate +// tensors here (tagged "DBG:"). Definition lives as `inline` to keep +// this file header-only. 
Set from LLMRunner::compute_all_hidden_states when +// the SD_DUMP_LAYER0 env var is present. +inline std::vector* g_layer0_taps = nullptr; + +// Helper: preserve a tap's value by routing the graph THROUGH a ggml_cont +// copy. Returning the cont'd tensor (instead of the original) means the +// next op in the graph consumes the cont, so the allocator has to keep the +// cont's buffer live. Mathematically a bitwise copy — no graph change. +// The cont's name starts with "DBG:" so the dumper can find it. +inline ggml_tensor* tap_tensor(ggml_context* ctx, ggml_tensor* t, const char* name) { + if (::g_layer0_taps == nullptr) return t; + ggml_tensor* keep = ggml_cont(ctx, t); + ggml_set_output(keep); // tell allocator: don't reuse my buffer + ggml_set_name(keep, (std::string("DBG:") + name).c_str()); + ::g_layer0_taps->push_back(keep); + return keep; +} + namespace LLM { - constexpr int LLM_GRAPH_SIZE = 10240; + // Bumped aggressively for the 22B LTX-2 smoke test where Gemma 3 12B runs with + // compute_all_hidden_states (49-layer concat stack over 48 layers of sandwich- + // norm + attn + MLP). The assert at ggml.c:6877 fired at 40960; 200000 leaves + // ample headroom while we diagnose whether real op count or hash dedup is the + // issue. + constexpr int LLM_GRAPH_SIZE = 200000; enum class LLMArch { QWEN2_5_VL, QWEN3, MISTRAL_SMALL_3_2, MINISTRAL_3_3B, + GEMMA3, ARCH_COUNT, }; @@ -37,6 +63,7 @@ namespace LLM { "qwen3", "mistral_small3.2", "ministral3.3b", + "gemma3", }; struct LLMVisionParams { @@ -65,15 +92,85 @@ namespace LLM { bool qk_norm = false; int64_t vocab_size = 152064; float rms_norm_eps = 1e-06f; + + // Gemma 3 additions (unused by other archs). + // Pattern: layers where (idx % sliding_window_pattern == 0) use global attention + // with rope_theta_global; other layers use sliding-window attention of size + // sliding_window with rope_theta_local. has_post_norms adds a second RMSNorm after + // attn and after MLP inside each block. 
embed_scale multiplies token embeddings + // once before the first layer. + int sliding_window = 0; // 0 = disabled + int sliding_window_pattern = 0; // 0 = disabled + float rope_theta_global = 0.f; // 0 = use legacy hardcoded theta + float rope_theta_local = 0.f; + // Gemma 3 rope_scaling: linear RoPE scaling applied only to full-attention + // (global) layers. HuggingFace config.json: rope_scaling={factor: F, rope_type: linear}. + // Sliding layers are unscaled. 1.0 = disabled. For the 12B model this is 8.0. + float rope_scaling_factor_global = 1.0f; + bool has_post_norms = false; + float embed_scale = 1.0f; + + // When true, Linear layers inside this model force GGML_PREC_F32 on + // their mul_mat ops. ggml-cuda defaults to F16 accumulation for + // quantized matmul, which drifts ~2% per layer vs the CPU/F32 path. + // For Gemma 3 used as a fixed embedding encoder (LTX-2) the compound + // drift across 48 layers corrupts the final embedding to uselessness + // on CUDA. Set true for Gemma 3; leave false for generative LLMs + // where the drift is acceptable and speed matters more. + bool force_matmul_prec_f32 = false; + LLMVisionParams vision; }; + // Gemma 3 RMSNorm variant: scale by (1 + w) instead of w. The PyTorch original + // Gemma3RMSNorm stores weights centered around 0 (so init scale is 1.0), and the + // forward applies `x * (1 + w)`. This class implements that math. + // + // IMPORTANT: this is currently UNUSED for production Gemma3, because our only + // supported Gemma3 source is GGUF. llama.cpp's `convert_hf_to_gguf.py` + // (`Gemma3Model.norm_shift`) bakes the +1 INTO the weights at convert time + // (so `w_gguf = w_pytorch + 1`), letting llama.cpp's runtime use the simpler + // `x * w` form. We therefore consume those GGUF weights with plain `RMSNorm` + // and the +1 is implicit in the weight values. If a non-GGUF Gemma3 loader + // is ever added, swap to this class for those code paths. 
+ class RMSNormPlus1 : public UnaryBlock { + protected: + int64_t hidden_size; + float eps; + std::string prefix; + + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override { + this->prefix = prefix; + params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + } + + public: + RMSNormPlus1(int64_t hidden_size, float eps = 1e-06f) + : hidden_size(hidden_size), eps(eps) {} + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + ggml_tensor* w = params["weight"]; + if (ctx->weight_adapter) { + w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight"); + } + x = ggml_rms_norm(ctx->ggml_ctx, x, eps); + auto scaled = ggml_mul(ctx->ggml_ctx, x, w); // rms(x) * w + x = ggml_add_inplace(ctx->ggml_ctx, x, scaled); // rms(x) * (1 + w) + return x; + } + }; + struct MLP : public GGMLBlock { + protected: + bool use_gelu_tanh; + public: - MLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false) { - blocks["gate_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias)); - blocks["up_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias)); - blocks["down_proj"] = std::shared_ptr(new Linear(intermediate_size, hidden_size, bias)); + MLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false, + bool use_gelu_tanh = false, bool force_prec_f32 = false) + : use_gelu_tanh(use_gelu_tanh) { + blocks["gate_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias, /*force_f32=*/false, force_prec_f32)); + blocks["up_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias, /*force_f32=*/false, force_prec_f32)); + blocks["down_proj"] = std::shared_ptr(new Linear(intermediate_size, hidden_size, bias, /*force_f32=*/false, force_prec_f32)); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { @@ -83,9 +180,13 @@ namespace LLM { auto down_proj = 
std::dynamic_pointer_cast(blocks["down_proj"]); auto h = gate_proj->forward(ctx, x); - h = ggml_silu_inplace(ctx->ggml_ctx, h); - h = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x)); - h = down_proj->forward(ctx, h); + if (use_gelu_tanh) { + h = ggml_gelu_inplace(ctx->ggml_ctx, h); + } else { + h = ggml_silu_inplace(ctx->ggml_ctx, h); + } + h = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x)); + h = down_proj->forward(ctx, h); return h; } }; @@ -376,24 +477,49 @@ namespace LLM { int64_t num_heads; int64_t num_kv_heads; bool qk_norm; + int layer_idx; + int sliding_window_pattern; + float rope_theta_global; + float rope_theta_local; + float rope_scaling_factor_global; public: - Attention(const LLMParams& params) - : arch(params.arch), num_heads(params.num_heads), num_kv_heads(params.num_kv_heads), head_dim(params.head_dim), qk_norm(params.qk_norm) { - blocks["q_proj"] = std::make_shared(params.hidden_size, num_heads * head_dim, params.qkv_bias); - blocks["k_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias); - blocks["v_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias); - blocks["o_proj"] = std::make_shared(num_heads * head_dim, params.hidden_size, false); + Attention(const LLMParams& params, int layer_idx = 0) + : arch(params.arch), + num_heads(params.num_heads), + num_kv_heads(params.num_kv_heads), + head_dim(params.head_dim), + qk_norm(params.qk_norm), + layer_idx(layer_idx), + sliding_window_pattern(params.sliding_window_pattern), + rope_theta_global(params.rope_theta_global), + rope_theta_local(params.rope_theta_local), + rope_scaling_factor_global(params.rope_scaling_factor_global) { + const bool fp = params.force_matmul_prec_f32; + blocks["q_proj"] = std::make_shared(params.hidden_size, num_heads * head_dim, params.qkv_bias, /*force_f32=*/false, fp); + blocks["k_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias, 
/*force_f32=*/false, fp); + blocks["v_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias, /*force_f32=*/false, fp); + blocks["o_proj"] = std::make_shared(num_heads * head_dim, params.hidden_size, false, /*force_f32=*/false, fp); if (params.qk_norm) { + // Gemma3 GGUF: weights have +1 baked in (see `RMSNormPlus1` comment), + // so plain `RMSNorm` produces `x * w_gguf == x * (w_pytorch + 1)` which + // matches the PyTorch reference's `x * (1 + w_pytorch)`. blocks["q_norm"] = std::make_shared(head_dim, params.rms_norm_eps); blocks["k_norm"] = std::make_shared(head_dim, params.rms_norm_eps); } } + bool is_gemma_sliding_layer() const { + return arch == LLMArch::GEMMA3 + && sliding_window_pattern > 0 + && ((layer_idx + 1) % sliding_window_pattern) != 0; + } + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* input_pos, - ggml_tensor* attention_mask = nullptr) { + ggml_tensor* attention_mask = nullptr, + ggml_tensor* attention_mask_sliding = nullptr) { // x: [N, n_token, hidden_size] int64_t n_token = x->ne[1]; int64_t N = x->ne[2]; @@ -402,20 +528,25 @@ namespace LLM { auto v_proj = std::dynamic_pointer_cast(blocks["v_proj"]); auto out_proj = std::dynamic_pointer_cast(blocks["o_proj"]); - auto q = q_proj->forward(ctx, x); // [N, n_token, num_heads*head_dim] - auto k = k_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim] - auto v = v_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim] + const bool trace = (layer_idx == 0); + auto tag = [&](ggml_tensor* t, const char* name) { + return trace ? 
tap_tensor(ctx->ggml_ctx, t, name) : t; + }; + + auto q = tag(q_proj->forward(ctx, x), "q_proj"); + auto k = tag(k_proj->forward(ctx, x), "k_proj"); + auto v = tag(v_proj->forward(ctx, x), "v_proj"); q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, num_heads, head_dim] k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] if (qk_norm) { - auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); - auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); + auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); + auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); - q = q_norm->forward(ctx, q); - k = k_norm->forward(ctx, k); + q = tag(q_norm->forward(ctx, q), "q_norm"); + k = tag(k_norm->forward(ctx, k), "k_norm"); } if (arch == LLMArch::MISTRAL_SMALL_3_2) { @@ -427,41 +558,106 @@ namespace LLM { } else if (arch == LLMArch::QWEN3) { q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + } else if (arch == LLMArch::GEMMA3) { + // Per-layer theta: global (full attention) layers use rope_theta_global, + // sliding layers use rope_theta_local. Pattern: is_global = ((l+1)%p == 0). + // Real Gemma 3 12B config also sets linear rope_scaling with factor=8.0 + // on full_attention only. HuggingFace divides inv_freq by factor, which + // ggml_rope_ext expresses as freq_scale = 1 / factor. + bool is_sliding = is_gemma_sliding_layer(); + float theta = is_sliding ? rope_theta_local : rope_theta_global; + float freq_scale = is_sliding ? 
1.0f : (1.0f / rope_scaling_factor_global); + q = tag(ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, GGML_ROPE_TYPE_NEOX, 1024, theta, freq_scale, 0.f, 1.f, 32.f, 1.f), "q_rope"); + k = tag(ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, head_dim, GGML_ROPE_TYPE_NEOX, 1024, theta, freq_scale, 0.f, 1.f, 32.f, 1.f), "k_rope"); } else { int sections[4] = {16, 24, 24, 0}; q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); k = ggml_rope_multi(ctx->ggml_ctx, k, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); } + // Gemma 3: pick the sliding-window mask for local layers. + if (is_gemma_sliding_layer() && attention_mask_sliding != nullptr) { + attention_mask = attention_mask_sliding; + } + q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 2, 1, 3)); // [N, num_heads, n_token, head_dim] q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); // [N*num_heads, n_token, head_dim] k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim] k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim] - x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, false); // [N, n_token, hidden_size] + x = tag(ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, false), "attn_out"); // [N, n_token, hidden_size] - x = out_proj->forward(ctx, x); // [N, n_token, hidden_size] + x = tag(out_proj->forward(ctx, x), "o_proj"); // [N, n_token, hidden_size] return x; } }; struct TransformerBlock : public GGMLBlock { + protected: + bool has_post_norms; + int layer_idx; + public: - TransformerBlock(const LLMParams& params) { - blocks["self_attn"] = std::make_shared(params); - 
blocks["mlp"] = std::make_shared(params.hidden_size, params.intermediate_size); - blocks["input_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); - blocks["post_attention_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + TransformerBlock(const LLMParams& params, int layer_idx = 0) + : has_post_norms(params.has_post_norms), layer_idx(layer_idx) { + bool gemma = (params.arch == LLMArch::GEMMA3); + blocks["self_attn"] = std::make_shared(params, layer_idx); + blocks["mlp"] = std::make_shared(params.hidden_size, params.intermediate_size, false, gemma, params.force_matmul_prec_f32); + + if (gemma) { + // GGUF Gemma3: weights have +1 baked in by llama.cpp's convert script, + // so plain `RMSNorm` is the right form. See `RMSNormPlus1` class comment. + blocks["input_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + blocks["post_attention_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + blocks["pre_feedforward_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + blocks["post_feedforward_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + } else { + blocks["input_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + blocks["post_attention_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + } } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* input_pos, - ggml_tensor* attention_mask = nullptr) { + ggml_tensor* attention_mask = nullptr, + ggml_tensor* attention_mask_sliding = nullptr) { // x: [N, n_token, hidden_size] - auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); - auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + + if (has_post_norms) { + // Gemma 3 sandwich: pre-attn-norm → attn → post-attn-norm → +res + // → pre-ff-norm 
→ mlp → post-ff-norm → +res. + auto input_ln = std::dynamic_pointer_cast(blocks["input_layernorm"]); + auto post_attn_ln = std::dynamic_pointer_cast(blocks["post_attention_layernorm"]); + auto pre_ff_ln = std::dynamic_pointer_cast(blocks["pre_feedforward_layernorm"]); + auto post_ff_ln = std::dynamic_pointer_cast(blocks["post_feedforward_layernorm"]); + + auto residual = x; + const bool trace_block = (layer_idx == 0); + auto tag = [&](ggml_tensor* t, const char* name) { + return trace_block ? tap_tensor(ctx->ggml_ctx, t, name) : t; + }; + if (trace_block) { + x = tag(x, "x_embed_in"); + residual = x; // residual must match the post-tap tensor + } + + x = tag(input_ln->forward(ctx, x), "input_ln"); + x = self_attn->forward(ctx, x, input_pos, attention_mask, attention_mask_sliding); + x = tag(post_attn_ln->forward(ctx, x), "post_attn_ln"); + x = tag(ggml_add_inplace(ctx->ggml_ctx, x, residual), "after_attn_res"); + + residual = x; + x = tag(pre_ff_ln->forward(ctx, x), "pre_ff_ln"); + x = tag(mlp->forward(ctx, x), "mlp_out"); + x = tag(post_ff_ln->forward(ctx, x), "post_ff_ln"); + x = tag(ggml_add_inplace(ctx->ggml_ctx, x, residual), "after_ff_res"); + return x; + } + auto input_layernorm = std::dynamic_pointer_cast(blocks["input_layernorm"]); auto post_attention_layernorm = std::dynamic_pointer_cast(blocks["post_attention_layernorm"]); @@ -482,14 +678,20 @@ namespace LLM { struct TextModel : public GGMLBlock { protected: int64_t num_layers; + float embed_scale; + bool has_post_norms; public: TextModel(const LLMParams& params) - : num_layers(params.num_layers) { + : num_layers(params.num_layers), + embed_scale(params.embed_scale), + has_post_norms(params.has_post_norms) { blocks["embed_tokens"] = std::shared_ptr(new Embedding(params.vocab_size, params.hidden_size)); for (int i = 0; i < num_layers; i++) { - blocks["layers." + std::to_string(i)] = std::shared_ptr(new TransformerBlock(params)); + blocks["layers." 
+ std::to_string(i)] = std::shared_ptr(new TransformerBlock(params, i)); } + // GGUF Gemma3 norm weights have +1 baked in (per llama.cpp convert), so plain + // RMSNorm is correct for both Gemma3 and other archs. blocks["norm"] = std::shared_ptr(new RMSNorm(params.hidden_size, params.rms_norm_eps)); } @@ -498,14 +700,24 @@ namespace LLM { ggml_tensor* input_pos, ggml_tensor* attention_mask, std::vector> image_embeds, - std::set out_layers) { + std::set out_layers, + ggml_tensor* attention_mask_sliding = nullptr, + std::vector* all_hidden_states = nullptr) { // input_ids: [N, n_token] // return: [N, n_token, hidden_size] auto embed_tokens = std::dynamic_pointer_cast(blocks["embed_tokens"]); - auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto x = embed_tokens->forward(ctx, input_ids); + x = tap_tensor(ctx->ggml_ctx, x, "embed_raw"); + if (embed_scale != 1.0f) { + x = ggml_scale(ctx->ggml_ctx, x, embed_scale); + x = tap_tensor(ctx->ggml_ctx, x, "embed_scaled"); + } + if (all_hidden_states) { + all_hidden_states->push_back(x); + } std::vector intermediate_outputs; @@ -551,7 +763,10 @@ namespace LLM { for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); - x = block->forward(ctx, x, input_pos, attention_mask); + x = block->forward(ctx, x, input_pos, attention_mask, attention_mask_sliding); + if (all_hidden_states) { + all_hidden_states->push_back(x); + } if (out_layers.find(i + 1) != out_layers.end()) { intermediate_outputs.push_back(x); } @@ -565,6 +780,12 @@ namespace LLM { } else { x = norm->forward(ctx, x); } + // HF Gemma 3 (and most HF causal-LM models): hidden_states[-1] is the + // POST-final-norm state. Replace the last pre-norm entry we stored with + // the normed version so downstream stacking matches exactly. 
+ if (all_hidden_states && !all_hidden_states->empty()) { + all_hidden_states->back() = x; + } return x; } }; @@ -599,11 +820,14 @@ namespace LLM { ggml_tensor* input_pos, ggml_tensor* attention_mask, std::vector> image_embeds, - std::set out_layers) { + std::set out_layers, + ggml_tensor* attention_mask_sliding = nullptr, + std::vector* all_hidden_states = nullptr) { // input_ids: [N, n_token] auto model = std::dynamic_pointer_cast(blocks["model"]); - auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); + auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers, + attention_mask_sliding, all_hidden_states); return x; } @@ -652,13 +876,47 @@ namespace LLM { params.qkv_bias = false; params.qk_norm = true; params.rms_norm_eps = 1e-6f; + } else if (arch == LLMArch::GEMMA3) { + // Gemma 3 12B (LTX-2 text encoder). See memory file + // .opencode/memories/2026-04-22_1000_gemma3-delta-note.md for derivation. + params.head_dim = 256; + params.num_heads = 16; + params.num_kv_heads = 8; + params.qkv_bias = false; + params.qk_norm = true; + params.rms_norm_eps = 1e-6f; + params.sliding_window = 1024; + params.sliding_window_pattern = 6; + params.rope_theta_global = 1000000.f; + params.rope_theta_local = 10000.f; + // Real Gemma 3 12B config.json sets rope_scaling={factor: 8.0, + // rope_type: linear} on full_attention layers. HuggingFace divides + // inv_freq by factor, which corresponds to ggml_rope_ext freq_scale + // = 1/factor. Sliding-attention layers stay unscaled. + params.rope_scaling_factor_global = 8.f; + params.has_post_norms = true; + // Gemma 3 has narrow weight scales; the CUDA mmvq/mmq kernels + // quantize activations to q8_1 (block-32 fp16 scale) while the + // CPU iq4_xs kernel uses q8_K (block-256 fp32 scale). That + // format mismatch causes ~5% per-layer drift and ruins the + // embedding. 
Requesting GGML_PREC_F32 routes matmul through + // cuBLAS dequant+GEMM, which matches CPU bit-for-bit. Even with + // Q8_0 weights this disables TF32 to keep prompt fidelity — + // without it the cumulative reduction-order drift across 48 + // layers shifts subject identity (cat → person on beach). + params.force_matmul_prec_f32 = true; + // embed_scale is sqrt(hidden_size); hidden_size is autodetected below, + // so defer setting embed_scale until after the tensor-storage scan. } bool have_vision_weight = false; bool llama_cpp_style = false; params.num_layers = 0; for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; - if (tensor_name.find(prefix) == std::string::npos) + // Use prefix-boundary match (must be followed by '.') rather than bare + // substring: otherwise e.g. prefix "text_encoder" would also match + // "text_encoder_deep.*" tensors and inflate auto-detected num_layers. + if (tensor_name.rfind(prefix + ".", 0) != 0) continue; size_t pos = tensor_name.find("visual."); if (pos != std::string::npos) { @@ -686,10 +944,32 @@ namespace LLM { if (contains(tensor_name, "layers.0.mlp.gate_proj.weight")) { params.intermediate_size = pair.second.ne[1]; } + if (arch == LLMArch::GEMMA3) { + // Gemma 3 has configurable head_dim (256 for 12B, 32 in our tiny test). + // q_norm.weight has shape [head_dim]; q_proj.weight is [hidden_size, num_heads*head_dim] + // and stored in GGML with ne[1]=num_heads*head_dim; likewise k_proj gives num_kv_heads. + if (contains(tensor_name, "layers.0.self_attn.q_norm.weight")) { + params.head_dim = (int)pair.second.ne[0]; + } + } } if (arch == LLMArch::QWEN3 && params.num_layers == 28) { // Qwen3 2B params.num_heads = 16; } + if (arch == LLMArch::GEMMA3) { + // Second pass: derive num_heads / num_kv_heads once head_dim is known. 
+ for (auto pair : tensor_storage_map) { + std::string tn = pair.first; + if (tn.rfind(prefix + ".", 0) != 0) continue; + if (contains(tn, "layers.0.self_attn.q_proj.weight") && params.head_dim > 0) { + params.num_heads = (int)(pair.second.ne[1] / params.head_dim); + } + if (contains(tn, "layers.0.self_attn.k_proj.weight") && params.head_dim > 0) { + params.num_kv_heads = (int)(pair.second.ne[1] / params.head_dim); + } + } + params.embed_scale = sqrtf((float)params.hidden_size); + } LOG_DEBUG("llm: num_layers = %" PRId64 ", vocab_size = %" PRId64 ", hidden_size = %" PRId64 ", intermediate_size = %" PRId64, params.num_layers, params.vocab_size, @@ -722,8 +1002,11 @@ namespace LLM { ggml_tensor* input_pos, ggml_tensor* attention_mask, std::vector> image_embeds, - std::set out_layers) { - auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); // [N, n_token, hidden_size] + std::set out_layers, + ggml_tensor* attention_mask_sliding = nullptr, + std::vector* all_hidden_states = nullptr) { + auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers, + attention_mask_sliding, all_hidden_states); // [N, n_token, hidden_size] return hidden_states; } @@ -737,11 +1020,16 @@ namespace LLM { return hidden_states; } + // Scratch storage for the Gemma sliding-window mask. 
+ std::vector sliding_attention_mask_vec; + ggml_cgraph* build_graph(const sd::Tensor& input_ids_tensor, const sd::Tensor& attention_mask_tensor, const std::vector>>& image_embeds_tensor, - std::set out_layers) { - ggml_cgraph* gf = ggml_new_graph(compute_ctx); + std::set out_layers, + std::vector* all_hidden_states = nullptr, + int pad_count = 0) { + ggml_cgraph* gf = new_graph_custom(LLM_GRAPH_SIZE); ggml_tensor* input_ids = make_input(input_ids_tensor); std::vector> image_embeds; image_embeds.reserve(image_embeds_tensor.size()); @@ -751,7 +1039,7 @@ namespace LLM { } int64_t n_tokens = input_ids->ne[0]; - if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::QWEN3) { + if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::QWEN3 || params.arch == LLMArch::GEMMA3) { input_pos_vec.resize(n_tokens); for (int i = 0; i < n_tokens; ++i) { input_pos_vec[i] = i; @@ -775,13 +1063,20 @@ namespace LLM { if (!attention_mask_tensor.empty()) { attention_mask = make_input(attention_mask_tensor); } else { + // Causal AND (when pad_count > 0) "real query cannot attend to a pad key" + // — required when the input is left-padded (pad tokens occupy [0, pad_count)). + // Pad-as-query rows still attend causally to earlier pads so softmax stays + // finite; pad outputs are discarded downstream. 
attention_mask_vec.resize(n_tokens * n_tokens); - for (int i0 = 0; i0 < n_tokens; i0++) { - for (int i1 = 0; i1 < n_tokens; i1++) { + for (int i0 = 0; i0 < n_tokens; i0++) { // i0 = key + for (int i1 = 0; i1 < n_tokens; i1++) { // i1 = query float value = 0.f; if (i0 > i1) { value = -INFINITY; } + if (pad_count > 0 && i0 < pad_count && i1 >= pad_count) { + value = -INFINITY; + } attention_mask_vec[i1 * n_tokens + i0] = value; } } @@ -789,9 +1084,32 @@ namespace LLM { set_backend_tensor_data(attention_mask, attention_mask_vec.data()); } + // Gemma 3 sliding-window mask: causal AND (q - k < window_size), with the + // same pad-as-key restriction applied so real queries don't pick up pad keys + // even within the sliding window. + ggml_tensor* attention_mask_sliding = nullptr; + if (params.arch == LLMArch::GEMMA3 && params.sliding_window > 0) { + sliding_attention_mask_vec.resize(n_tokens * n_tokens); + for (int q = 0; q < n_tokens; q++) { + for (int k = 0; k < n_tokens; k++) { + float value = 0.f; + if (k > q || (q - k) >= params.sliding_window) { + value = -INFINITY; + } + if (pad_count > 0 && k < pad_count && q >= pad_count) { + value = -INFINITY; + } + sliding_attention_mask_vec[q * n_tokens + k] = value; + } + } + attention_mask_sliding = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens); + set_backend_tensor_data(attention_mask_sliding, sliding_attention_mask_vec.data()); + } + auto runner_ctx = get_context(); - ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); + ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers, + attention_mask_sliding, all_hidden_states); ggml_build_forward_expand(gf, hidden_states); @@ -809,6 +1127,86 @@ namespace LLM { return take_or_empty(GGMLRunner::compute(get_graph, n_threads, true)); } + // Returns all N+1 hidden states (embedding + each transformer layer, with + // final layer's output 
post-model.norm). Stacked along a new innermost axis, + // shape in GGML: ne=[num_layers+1, hidden_size, n_tokens, batch] which matches + // PyTorch `torch.stack(hidden_states, dim=-1)` layout of [B, T, H, N+1]. + sd::Tensor compute_all_hidden_states(const int n_threads, + const sd::Tensor& input_ids, + const sd::Tensor& attention_mask, + int pad_count = 0) { + // Debug hook: capture layer-0 intermediates via the global tap vector. + // Forward paths push tensors here when ::g_layer0_taps != nullptr. + std::vector taps; + const char* dump_dir = std::getenv("SD_DUMP_LAYER0"); + if (dump_dir != nullptr) ::g_layer0_taps = &taps; + struct TapGuard { + ~TapGuard() { ::g_layer0_taps = nullptr; } + } guard; + + auto get_graph = [&]() -> ggml_cgraph* { + // GGMLRunner::compute calls this lambda TWICE — once to measure + // the allocator, then reset_compute_ctx wipes all tensors and + // it's called again to actually compute. Both builds need to + // re-fire attn_tap so the second build's tensors get our names + // (otherwise ggml auto-names them "node_X" since the first + // build's set_name applied to since-dead tensors). + ::g_attn_tap_count = 0; + taps.clear(); // also clear the OUTER taps so we only collect + // pointers from the latest (compute-pass) build + std::vector hidden_states; + ggml_cgraph* gf = build_graph(input_ids, attention_mask, {}, {}, &hidden_states, pad_count); + + // Keep taps alive through the allocator: mark each as an output + // (prevents buffer aliasing) and expand into the graph. + for (auto* t : taps) { + ggml_set_output(t); + ggml_build_forward_expand(gf, t); + } + + GGML_ASSERT(!hidden_states.empty()); + // Reshape each [H, T, B] -> [1, H, T, B] so we can concat along axis 0. 
+ ggml_tensor* stacked = nullptr; + for (auto* h : hidden_states) { + auto h_cont = ggml_cont(compute_ctx, h); + auto h_4d = ggml_reshape_4d(compute_ctx, h_cont, 1, h_cont->ne[0], h_cont->ne[1], h_cont->ne[2]); + if (stacked == nullptr) { + stacked = h_4d; + } else { + stacked = ggml_concat(compute_ctx, stacked, h_4d, 0); + } + } + ggml_build_forward_expand(gf, stacked); + return gf; + }; + auto result = take_or_empty(GGMLRunner::compute(get_graph, n_threads, /*free_compute_buffer_immediately=*/false)); + + if (dump_dir != nullptr && !taps.empty()) { + LOG_INFO("SD_DUMP_LAYER0: dumping %zu tensors to %s/", taps.size(), dump_dir); + for (auto* t : taps) { + const char* full_name = ggml_get_name(t); + if (std::strncmp(full_name, "DBG:", 4) != 0) continue; + const char* name = full_name + 4; + size_t nbytes = ggml_nbytes(t); + std::vector buf(nbytes); + ggml_backend_tensor_get(t, buf.data(), 0, nbytes); + std::string path = std::string(dump_dir) + "/" + name + ".bin"; + FILE* f = std::fopen(path.c_str(), "wb"); + if (f) { + std::fwrite(buf.data(), 1, nbytes, f); + std::fclose(f); + LOG_INFO(" %-22s ne=[%ld,%ld,%ld,%ld] type=%s bytes=%zu -> %s", + name, (long)t->ne[0], (long)t->ne[1], (long)t->ne[2], (long)t->ne[3], + ggml_type_name(t->type), nbytes, path.c_str()); + } + } + // Free now so we don't leak the compute buffer. + free_compute_buffer(); + } + + return result; + } + int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) { int64_t grid_t = 1; int64_t grid_h = h / params.vision.patch_size; @@ -989,6 +1387,10 @@ namespace LLM { : model(arch, backend, offload_params_to_cpu, tensor_storage_map, prefix, enable_vision) { if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) { tokenizer = std::make_shared(); + } else if (arch == LLMArch::GEMMA3) { + // Gemma 3 uses SentencePiece (vocab 262208). A SentencePiece loader is + // not yet implemented in this repo; tokenization path lands in task #25. 
+ GGML_ABORT("Gemma 3 SentencePiece tokenizer not implemented yet"); } else { tokenizer = std::make_shared(); } diff --git a/src/ltx.hpp b/src/ltx.hpp new file mode 100644 index 000000000..7dbcf4cc4 --- /dev/null +++ b/src/ltx.hpp @@ -0,0 +1,1273 @@ +#ifndef __LTX_HPP__ +#define __LTX_HPP__ + +#include +#include +#include +#include + +#include "ggml_extend.hpp" +#include "ltx_rope.hpp" +#include "model.h" + +// LTX-2 video DiT. +// Reference: /devel/tools/diffusion/LTX-2/packages/ltx-core/src/ltx_core/model/transformer/ +// +// Scope (first landing): text-conditioned video-only (LTXModelType.VideoOnly), rope_type=INTERLEAVED, +// cross_attention_adaln=false, apply_gated_attention=false. Audio pathway and AV cross-attention are +// deferred (stubbed out) — the weights are just not instantiated. + +namespace LTX { + // 32768 was enough for the 2-layer parity-test DiT. The 22B V2 has 48 layers + // + cross_attention_adaln + prompt_adaln_single, roughly 2-3× the op count + // per block vs. V1. Bump generously so graph construction never fails the + // `cgraph->n_nodes < cgraph->size` assert in ggml's append path. + constexpr int LTX_GRAPH_SIZE = 131072; + constexpr int TIME_PROJ_DIM = 256; + constexpr int ADALN_BASE = 6; + constexpr int ADALN_WITH_CA = 9; + + // Python: ltx_core.model.transformer.rope.LTXRopeType. Real LTX-2.3 config uses + // SPLIT; earlier LTX variants (and our parity test's old default) were INTERLEAVED. + enum class RopeType { INTERLEAVED, SPLIT }; + + // Parameter-free RMSNorm helper. 
+ __STATIC_INLINE__ ggml_tensor* parameterless_rms_norm(ggml_context* ctx, ggml_tensor* x, float eps = 1e-6f) { + return ggml_rms_norm(ctx, x, eps); + } + + struct AdaLayerNormSingle : public GGMLBlock { + protected: + int embedding_dim; + int embedding_coefficient; + + public: + AdaLayerNormSingle() = default; + AdaLayerNormSingle(int embedding_dim, int embedding_coefficient) + : embedding_dim(embedding_dim), embedding_coefficient(embedding_coefficient) { + // Python: self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim, size_emb_dim=embedding_dim // 3) + // -> time_proj: sinusoidal (no weights) + // -> timestep_embedder.linear_1: Linear(256, embedding_dim) + // -> timestep_embedder.linear_2: Linear(embedding_dim, embedding_dim) + // Python: self.linear = Linear(embedding_dim, coefficient * embedding_dim) + blocks["emb.timestep_embedder.linear_1"] = std::make_shared(TIME_PROJ_DIM, embedding_dim, true); + blocks["emb.timestep_embedder.linear_2"] = std::make_shared(embedding_dim, embedding_dim, true); + blocks["linear"] = std::make_shared(embedding_dim, embedding_coefficient * embedding_dim, true); + } + + // timestep: [B] — caller MUST pass the pre-scaled timestep (σ * timestep_scale_multiplier). + // Python applies the scaling in TransformerArgsPreprocessor._prepare_timestep; we mirror that + // boundary so the denoiser (sigma_to_t) is the single place that owns the 1000× factor. + // Double-scaling (denoiser + AdaLN) would drive sinusoidal embedding args to σ·1e6, which is + // numerical nonsense and was a real risk before this refactor. + // + // Returns {modulation, embedded_timestep}. 
+ // modulation ne: [embedding_dim, coefficient, B] + // embedded_timestep ne: [embedding_dim, B] + std::pair forward(GGMLRunnerContext* ctx, + ggml_tensor* timestep) { + auto l1 = std::dynamic_pointer_cast(blocks["emb.timestep_embedder.linear_1"]); + auto l2 = std::dynamic_pointer_cast(blocks["emb.timestep_embedder.linear_2"]); + auto proj = std::dynamic_pointer_cast(blocks["linear"]); + + auto t_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, TIME_PROJ_DIM, 10000, 1.0f); + auto hidden = l1->forward(ctx, t_proj); + hidden = ggml_silu_inplace(ctx->ggml_ctx, hidden); + auto embedded = l2->forward(ctx, hidden); // [embedding_dim, B] + + auto modulation = ggml_silu(ctx->ggml_ctx, embedded); + modulation = proj->forward(ctx, modulation); // [coeff*embedding_dim, B] + + int64_t B = modulation->ne[1]; + modulation = ggml_reshape_3d(ctx->ggml_ctx, modulation, embedding_dim, embedding_coefficient, B); + return {modulation, embedded}; + } + }; + + // GELUApprox block: Linear(dim_in → dim_out) + gelu(tanh approximation). + // Python: GELUApprox uses torch.nn.functional.gelu(..., approximate="tanh") which matches ggml_gelu. 
+ struct GELUApprox : public GGMLBlock { + public: + GELUApprox() = default; + GELUApprox(int64_t dim_in, int64_t dim_out) { + blocks["proj"] = std::make_shared(dim_in, dim_out, true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto proj = std::dynamic_pointer_cast(blocks["proj"]); + x = proj->forward(ctx, x); + return ggml_ext_gelu(ctx->ggml_ctx, x, true); + } + }; + + struct FeedForward : public GGMLBlock { + public: + FeedForward() = default; + FeedForward(int64_t dim, int64_t dim_out, int mult = 4) { + int64_t inner = dim * mult; + // Python: self.net = Sequential(GELUApprox(dim, inner), Identity(), Linear(inner, dim_out)) + blocks["net.0"] = std::make_shared(dim, inner); + blocks["net.2"] = std::make_shared(inner, dim_out, true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto gelu_proj = std::dynamic_pointer_cast(blocks["net.0"]); + auto out_proj = std::dynamic_pointer_cast(blocks["net.2"]); + x = gelu_proj->forward(ctx, x); + x = out_proj->forward(ctx, x); + return x; + } + }; + + struct LTXAttention : public GGMLBlock { + protected: + int64_t query_dim; + int64_t context_dim; + int num_heads; + int head_dim; + int64_t inner_dim; + float norm_eps; + bool apply_gated_attention; + RopeType rope_type; + + public: + LTXAttention() = default; + LTXAttention(int64_t query_dim, int64_t context_dim, int num_heads, int head_dim, + bool apply_gated_attention = false, float norm_eps = 1e-6f, + RopeType rope_type = RopeType::SPLIT) + : query_dim(query_dim), context_dim(context_dim), num_heads(num_heads), + head_dim(head_dim), inner_dim(static_cast(num_heads) * head_dim), + norm_eps(norm_eps), apply_gated_attention(apply_gated_attention), + rope_type(rope_type) { + blocks["to_q"] = std::make_shared(query_dim, inner_dim, true); + blocks["to_k"] = std::make_shared(context_dim, inner_dim, true); + blocks["to_v"] = std::make_shared(context_dim, inner_dim, true); + blocks["q_norm"] = std::make_shared(inner_dim, 
norm_eps); + blocks["k_norm"] = std::make_shared(inner_dim, norm_eps); + blocks["to_out.0"] = std::make_shared(inner_dim, query_dim, true); + if (apply_gated_attention) { + blocks["to_gate_logits"] = std::make_shared(query_dim, num_heads, true); + } + } + + // x: [query_dim, L_q, B] + // context: [context_dim, L_kv, B] (defaults to x for self-attn) + // pe: optional packed cos/sin [inner_dim, L_q, 2] applied to Q + // mask: optional additive attention mask + // k_pe: optional separate cos/sin [inner_dim, L_kv, 2] applied to K. Null → + // K uses `pe` (same length as Q). Used for cross-modal attention where + // Q and K have different sequence lengths and per-modality positional + // embeddings (audio_to_video_attn / video_to_audio_attn). + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* context, + ggml_tensor* pe, + ggml_tensor* mask = nullptr, + ggml_tensor* k_pe = nullptr) { + if (context == nullptr) { + context = x; + } + auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); + auto to_k = std::dynamic_pointer_cast(blocks["to_k"]); + auto to_v = std::dynamic_pointer_cast(blocks["to_v"]); + auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); + auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); + auto to_out = std::dynamic_pointer_cast(blocks["to_out.0"]); + + auto q = to_q->forward(ctx, x); // [inner_dim, L_q, B] + auto k = to_k->forward(ctx, context); // [inner_dim, L_kv, B] + auto v = to_v->forward(ctx, context); // [inner_dim, L_kv, B] + + q = q_norm->forward(ctx, q); + k = k_norm->forward(ctx, k); + + if (pe != nullptr) { + ggml_tensor* k_pe_eff = (k_pe != nullptr) ? 
k_pe : pe; + if (rope_type == RopeType::SPLIT) { + auto q_cs = LTXRope::split_pe_split(ctx->ggml_ctx, pe); + q = LTXRope::apply_rotary_emb_split(ctx->ggml_ctx, q, q_cs.first, q_cs.second, num_heads); + auto k_cs = LTXRope::split_pe_split(ctx->ggml_ctx, k_pe_eff); + k = LTXRope::apply_rotary_emb_split(ctx->ggml_ctx, k, k_cs.first, k_cs.second, num_heads); + } else { + auto q_cs = LTXRope::split_pe(ctx->ggml_ctx, pe); + q = LTXRope::apply_rotary_emb_interleaved(ctx->ggml_ctx, q, q_cs.first, q_cs.second); + auto k_cs = LTXRope::split_pe(ctx->ggml_ctx, k_pe_eff); + k = LTXRope::apply_rotary_emb_interleaved(ctx->ggml_ctx, k, k_cs.first, k_cs.second); + } + } + + auto out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, + num_heads, mask, false, ctx->flash_attn_enabled); + // out: [inner_dim, L_q, B] + + if (apply_gated_attention) { + auto gate_proj = std::dynamic_pointer_cast(blocks["to_gate_logits"]); + auto gate_logits = gate_proj->forward(ctx, x); // [num_heads, L_q, B] + auto gates = ggml_sigmoid(ctx->ggml_ctx, gate_logits); + gates = ggml_scale(ctx->ggml_ctx, gates, 2.f); + // out is [inner_dim, L_q, B]; reshape to [head_dim, num_heads, L_q, B], multiply gates as [1, num_heads, L_q, B] broadcast. + int64_t L_q = out->ne[1]; + int64_t B = out->ne[2]; + auto out4 = ggml_reshape_4d(ctx->ggml_ctx, out, head_dim, num_heads, L_q, B); + auto g4 = ggml_reshape_4d(ctx->ggml_ctx, gates, 1, num_heads, L_q, B); + out4 = ggml_mul(ctx->ggml_ctx, out4, g4); + out = ggml_reshape_3d(ctx->ggml_ctx, out4, inner_dim, L_q, B); + } + + out = to_out->forward(ctx, out); // [query_dim, L_q, B] + return out; + } + }; + + // PixArtAlphaTextProjection — caption_projection inside the DiT. + // Python: ltx_core/model/transformer/text_projection.py. + // linear_1 (caption_channels → hidden) → GELU(tanh) → linear_2 (hidden → out). + // Used in V1 / 19B to bring the connector's 3840-dim output up to the DiT's + // 4096-dim inner space. 
In config the `caption_proj_before_connector` flag + // distinguishes V1 (True, used here) from V2 (False, handled separately). + struct PixArtAlphaTextProjection : public GGMLBlock { + protected: + int64_t in_features; + int64_t hidden_size; + int64_t out_features; + + public: + PixArtAlphaTextProjection() = default; + PixArtAlphaTextProjection(int64_t in_features, int64_t hidden_size, int64_t out_features = 0) + : in_features(in_features), hidden_size(hidden_size), + out_features(out_features == 0 ? hidden_size : out_features) { + blocks["linear_1"] = std::make_shared(in_features, hidden_size, true); + blocks["linear_2"] = std::make_shared(hidden_size, this->out_features, true); + } + + int64_t get_in_features() const { return in_features; } + int64_t get_out_features() const { return out_features; } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto l1 = std::dynamic_pointer_cast(blocks["linear_1"]); + auto l2 = std::dynamic_pointer_cast(blocks["linear_2"]); + x = l1->forward(ctx, x); + x = ggml_ext_gelu(ctx->ggml_ctx, x, /*approximate_tanh=*/true); + x = l2->forward(ctx, x); + return x; + } + }; + + // Args for one modality (video or audio) flowing through an AV transformer + // block. Mirrors python TransformerArgs but only the fields the block needs. + struct LTX2AVModalityArgs { + ggml_tensor* x = nullptr; // [dim, L, B] — set null to skip this modality + ggml_tensor* context = nullptr; // [ctx_dim, L_ctx, B] + ggml_tensor* modulation = nullptr; // [dim, 6_or_9, B] + ggml_tensor* pe = nullptr; // [inner_dim, L, 2] + ggml_tensor* prompt_modulation = nullptr; // [dim, 2, B] or null (cross_attention_adaln only) + ggml_tensor* context_mask = nullptr; // [L_ctx, L, 1, B] or null + // Cross-modal modulation tensors. Each block uses table[0:2] + cross_scale_shift_modulation[:,0:2,:] + // for its a2v slot, table[2:4] + cross_scale_shift_modulation[:,2:4,:] for its v2a slot, + // and table[4:5] + cross_gate_modulation for the gate. 
+ ggml_tensor* cross_scale_shift_modulation = nullptr; // [dim, 4, B] + ggml_tensor* cross_gate_modulation = nullptr; // [dim, 1, B] + // Cross-modal RoPE positional embeddings. Computed at the model level + // from the union of video+audio max_pos so both sides share scale. + ggml_tensor* cross_pe = nullptr; // [inner_dim_cross, L_cross, 2] + }; + + struct LTXTransformerBlock : public GGMLBlock { + protected: + int64_t dim; + int num_heads; + int head_dim; + int64_t context_dim; + bool cross_attention_adaln; + bool apply_gated_attention; + float norm_eps; + + // --- audio-video extension --- + // When `has_audio_video == true`, the block additionally carries the + // audio-side self-attn / text-CA / FFN, and the cross-modal a2v / v2a + // attentions plus their scale_shift_table_a2v_ca_{audio,video} tables. + // This mirrors python BasicAVTransformerBlock when `audio is not None + // and video is not None`. + bool has_audio_video = false; + int64_t audio_dim = 0; + int audio_num_heads = 0; + int audio_head_dim = 0; + int64_t audio_context_dim = 0; + + void init_params(ggml_context* ctx, const String2TensorStorage&, const std::string prefix = "") override { + int num_params = cross_attention_adaln ? ADALN_WITH_CA : ADALN_BASE; + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, num_params); + if (cross_attention_adaln) { + params["prompt_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, 2); + } + if (has_audio_video) { + // audio_scale_shift_table mirrors video's: 6 rows (or 9 with cross_attention_adaln) + params["audio_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, audio_dim, num_params); + if (cross_attention_adaln) { + params["audio_prompt_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, audio_dim, 2); + } + // 5-row tables: rows 0-1 = a2v scale/shift, rows 2-3 = v2a scale/shift, row 4 = gate. 
+ params["scale_shift_table_a2v_ca_audio"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, audio_dim, 5); + params["scale_shift_table_a2v_ca_video"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, 5); + } + } + + public: + LTXTransformerBlock() = default; + LTXTransformerBlock(int64_t dim, int num_heads, int head_dim, int64_t context_dim, + bool cross_attention_adaln = false, bool apply_gated_attention = false, + float norm_eps = 1e-6f, + RopeType rope_type = RopeType::SPLIT, + // Audio-video config — set audio_dim > 0 to enable. Defaults disable + // the audio path so existing video-only construction is unchanged. + int64_t audio_dim = 0, + int audio_num_heads = 0, + int audio_head_dim = 0, + int64_t audio_context_dim = 0) + : dim(dim), num_heads(num_heads), head_dim(head_dim), context_dim(context_dim), + cross_attention_adaln(cross_attention_adaln), + apply_gated_attention(apply_gated_attention), norm_eps(norm_eps), + has_audio_video(audio_dim > 0), + audio_dim(audio_dim), audio_num_heads(audio_num_heads), + audio_head_dim(audio_head_dim), audio_context_dim(audio_context_dim) { + blocks["attn1"] = std::make_shared(dim, dim, num_heads, head_dim, apply_gated_attention, norm_eps, rope_type); + blocks["attn2"] = std::make_shared(dim, context_dim, num_heads, head_dim, apply_gated_attention, norm_eps, rope_type); + blocks["ff"] = std::make_shared(dim, dim); + if (has_audio_video) { + // Audio self-attention + audio text cross-attention + audio FFN. + // Heads/d_head are AUDIO's (typically 32×64 vs video's 32×128). 
+ blocks["audio_attn1"] = std::make_shared( + audio_dim, audio_dim, audio_num_heads, audio_head_dim, + apply_gated_attention, norm_eps, rope_type); + blocks["audio_attn2"] = std::make_shared( + audio_dim, audio_context_dim, audio_num_heads, audio_head_dim, + apply_gated_attention, norm_eps, rope_type); + blocks["audio_ff"] = std::make_shared(audio_dim, audio_dim); + + // Cross-modal: query_dim is the QUERYING modality's; context_dim is + // the OTHER modality's; heads/d_head come from AUDIO config (for both, + // matching python BasicAVTransformerBlock). + blocks["audio_to_video_attn"] = std::make_shared( + /*query_dim=*/dim, /*context_dim=*/audio_dim, + audio_num_heads, audio_head_dim, + apply_gated_attention, norm_eps, rope_type); + blocks["video_to_audio_attn"] = std::make_shared( + /*query_dim=*/audio_dim, /*context_dim=*/dim, + audio_num_heads, audio_head_dim, + apply_gated_attention, norm_eps, rope_type); + } + } + + // Helper — returns a triple (a, b, c) from scale_shift_table[start:start+3] + modulation[:, start:start+3, :] + // scale_shift_table: ne [dim, num_params] + // modulation: ne [dim, num_params, B] + // Returns three tensors each ne [dim, 1, B]. + std::tuple extract_triple(ggml_context* ctx, + ggml_tensor* sst, + ggml_tensor* modulation, + int start) { + int64_t B = modulation->ne[2]; + + // Slice scale_shift_table rows [start, start+3). + auto sst_slice = ggml_ext_slice(ctx, sst, 1, start, start + 3); // ne [dim, 3] + + // Slice modulation along dim 1 [start, start+3). + auto mod_slice = ggml_ext_slice(ctx, modulation, 1, start, start + 3); // ne [dim, 3, B] + + // Broadcast add: sst_slice [dim, 3] + mod_slice [dim, 3, B] → [dim, 3, B]. + auto combined = ggml_add(ctx, mod_slice, sst_slice); + + auto chunks = ggml_ext_chunk(ctx, combined, 3, 1); + // Each chunk ne [dim, 1, B] + return std::make_tuple(chunks[0], chunks[1], chunks[2]); + } + + // Extract (shift, scale) from prompt_scale_shift_table [dim, 2] + prompt_modulation [dim, 2, B]. 
+ // Python: `(prompt_scale_shift_table[None, None] + prompt_timestep.reshape(...,2,-1)).unbind(2)`. + std::pair extract_kv_pair(ggml_context* ctx, + ggml_tensor* psst, + ggml_tensor* prompt_mod) { + auto combined = ggml_add(ctx, prompt_mod, psst); // [dim, 2, B] + auto chunks = ggml_ext_chunk(ctx, combined, 2, 1); + return {chunks[0], chunks[1]}; // (shift_kv, scale_kv), each [dim, 1, B] + } + + // Extract (scale, shift, gate) for an AV cross-modal slot. + // table: [dim, 5] — row 0/1 a2v scale/shift, row 2/3 v2a scale/shift, row 4 gate + // ss_mod: [dim, 4, B] — modulation matching the 4 scale/shift rows + // gate_mod: [dim, 1, B] — modulation for the gate row + // start: 0 for the a2v slot, 2 for the v2a slot + // Returns three tensors each ne [dim, 1, B]. + std::tuple extract_av_modulation( + ggml_context* ctx, ggml_tensor* table, ggml_tensor* ss_mod, ggml_tensor* gate_mod, int start) { + // scale,shift = table[start:start+2] + ss_mod[:, start:start+2, :] + auto sst_ss = ggml_ext_slice(ctx, table, 1, start, start + 2); // [dim, 2] + auto mod_ss = ggml_ext_slice(ctx, ss_mod, 1, start, start + 2); // [dim, 2, B] + auto sum_ss = ggml_add(ctx, mod_ss, sst_ss); // [dim, 2, B] + auto chunks_ss = ggml_ext_chunk(ctx, sum_ss, 2, 1); // 2× [dim, 1, B] + auto scale = chunks_ss[0]; + auto shift = chunks_ss[1]; + + // gate = table[4:5] + gate_mod + auto sst_g = ggml_ext_slice(ctx, table, 1, 4, 5); // [dim, 1] + auto sum_g = ggml_add(ctx, gate_mod, sst_g); // [dim, 1, B] + auto chunks_g = ggml_ext_chunk(ctx, sum_g, 1, 1); // 1× [dim, 1, B] + auto gate = chunks_g[0]; + return std::make_tuple(scale, shift, gate); + } + + // Apply text cross-attention path (V1 or V2 per cross_attention_adaln). + // Mirrors the inner block of forward() so audio-side text-CA can reuse it. 
+ ggml_tensor* apply_text_ca(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* context, + ggml_tensor* context_mask, + std::shared_ptr attn, + ggml_tensor* sst, + ggml_tensor* modulation, + ggml_tensor* prompt_modulation, + ggml_tensor* prompt_sst) { + if (cross_attention_adaln) { + GGML_ASSERT(prompt_modulation != nullptr && prompt_sst != nullptr); + auto triple_ca = extract_triple(ctx->ggml_ctx, sst, modulation, 6); + auto shift_q = std::get<0>(triple_ca); + auto scale_q = std::get<1>(triple_ca); + auto gate_q = std::get<2>(triple_ca); + + auto kv_pair = extract_kv_pair(ctx->ggml_ctx, prompt_sst, prompt_modulation); + auto shift_kv = kv_pair.first; + auto scale_kv = kv_pair.second; + + auto norm_x_ca = parameterless_rms_norm(ctx->ggml_ctx, x, norm_eps); + auto q_scaled = ggml_add(ctx->ggml_ctx, norm_x_ca, ggml_mul(ctx->ggml_ctx, norm_x_ca, scale_q)); + auto q_modulated = ggml_add(ctx->ggml_ctx, q_scaled, shift_q); + auto ctx_scaled = ggml_add(ctx->ggml_ctx, context, ggml_mul(ctx->ggml_ctx, context, scale_kv)); + auto ctx_modulated = ggml_add(ctx->ggml_ctx, ctx_scaled, shift_kv); + + auto ca_out = attn->forward(ctx, q_modulated, ctx_modulated, nullptr, context_mask); + return ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, ca_out, gate_q)); + } else { + auto norm_x_ca = parameterless_rms_norm(ctx->ggml_ctx, x, norm_eps); + auto ca_out = attn->forward(ctx, norm_x_ca, context, nullptr, context_mask); + return ggml_add(ctx->ggml_ctx, x, ca_out); + } + } + + // x: [dim, L_q, B] + // context: [context_dim, L_kv, B] + // modulation: [dim, num_params, B] (num_params = 6 for V1, 9 for V2) + // pe: packed cos/sin tensor [dim, L_q, 2] + // prompt_modulation: [dim, 2, B] — required when cross_attention_adaln=true, else nullptr + // context_mask: [L_kv, L_q, 1, B] additive mask (or nullptr) + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* context, + ggml_tensor* modulation, + ggml_tensor* pe, + ggml_tensor* prompt_modulation = nullptr, 
+ ggml_tensor* context_mask = nullptr, + // STG (Spatio-Temporal Guidance) perturbation: when true, + // bypass video self-attention entirely. Mirrors python's + // SKIP_VIDEO_SELF_ATTN with all_perturbed=True. The block's + // residual passes through unchanged for the self-attn step. + bool skip_self_attn = false) { + auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); + auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + auto ff = std::dynamic_pointer_cast(blocks["ff"]); + + auto sst = params["scale_shift_table"]; + + // --- Self-attention path (modulation slice 0:3 → shift, scale, gate) --- + // Skipped entirely when skip_self_attn is set (STG perturbation pass). + if (!skip_self_attn) { + auto triple1 = extract_triple(ctx->ggml_ctx, sst, modulation, 0); + auto shift_msa = std::get<0>(triple1); + auto scale_msa = std::get<1>(triple1); + auto gate_msa = std::get<2>(triple1); + + auto norm_x = parameterless_rms_norm(ctx->ggml_ctx, x, norm_eps); + auto scaled = ggml_add(ctx->ggml_ctx, norm_x, ggml_mul(ctx->ggml_ctx, norm_x, scale_msa)); + auto modulated = ggml_add(ctx->ggml_ctx, scaled, shift_msa); + auto attn_out = attn1->forward(ctx, modulated, nullptr, pe, nullptr); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa)); + } + + // --- Cross-attention --- + // V1 (cross_attention_adaln=false): plain rms_norm → attn2 → residual. 
+ // V2 (cross_attention_adaln=true): + // modulation[6:9] → (q_shift, q_scale, q_gate) for the query path + // prompt_scale_shift_table + prompt_modulation → (kv_shift, kv_scale) for the context + // attn_input = rms_norm(x) * (1 + q_scale) + q_shift + // context_mod = context * (1 + kv_scale) + kv_shift + // x = x + attn2(attn_input, context_mod) * q_gate + if (cross_attention_adaln) { + GGML_ASSERT(prompt_modulation != nullptr && "cross_attention_adaln requires prompt_modulation"); + auto triple_ca = extract_triple(ctx->ggml_ctx, sst, modulation, 6); + auto shift_q = std::get<0>(triple_ca); + auto scale_q = std::get<1>(triple_ca); + auto gate_q = std::get<2>(triple_ca); + + auto psst = params["prompt_scale_shift_table"]; // [dim, 2] + auto kv_pair = extract_kv_pair(ctx->ggml_ctx, psst, prompt_modulation); + auto shift_kv = kv_pair.first; + auto scale_kv = kv_pair.second; + + auto norm_x_ca = parameterless_rms_norm(ctx->ggml_ctx, x, norm_eps); + auto q_scaled = ggml_add(ctx->ggml_ctx, norm_x_ca, ggml_mul(ctx->ggml_ctx, norm_x_ca, scale_q)); + auto q_modulated = ggml_add(ctx->ggml_ctx, q_scaled, shift_q); + auto ctx_scaled = ggml_add(ctx->ggml_ctx, context, ggml_mul(ctx->ggml_ctx, context, scale_kv)); + auto ctx_modulated = ggml_add(ctx->ggml_ctx, ctx_scaled, shift_kv); + + auto ca_out = attn2->forward(ctx, q_modulated, ctx_modulated, nullptr, context_mask); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, ca_out, gate_q)); + } else { + auto norm_x_ca = parameterless_rms_norm(ctx->ggml_ctx, x, norm_eps); + auto ca_out = attn2->forward(ctx, norm_x_ca, context, nullptr, context_mask); + x = ggml_add(ctx->ggml_ctx, x, ca_out); + } + + // --- FeedForward path (modulation slice 3:6 → shift, scale, gate) --- + auto triple2 = extract_triple(ctx->ggml_ctx, sst, modulation, 3); + auto shift_mlp = std::get<0>(triple2); + auto scale_mlp = std::get<1>(triple2); + auto gate_mlp = std::get<2>(triple2); + + auto norm_x2 = parameterless_rms_norm(ctx->ggml_ctx, x, 
norm_eps); + auto scaled_mlp = ggml_add(ctx->ggml_ctx, norm_x2, ggml_mul(ctx->ggml_ctx, norm_x2, scale_mlp)); + auto modulated_mlp = ggml_add(ctx->ggml_ctx, scaled_mlp, shift_mlp); + auto ff_out = ff->forward(ctx, modulated_mlp); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, ff_out, gate_mlp)); + + return x; + } + + // Audio-video forward. Mirrors python BasicAVTransformerBlock.forward. + // Either modality may be skipped by setting its `x = nullptr`. When + // both are provided, the cross-modal a2v / v2a paths run. + // Returns {video_out, audio_out}; either may be null when skipped. + // NOTE: perturbation masks (PerturbationType.SKIP_*) are not modeled — + // the block is always all-on. Parity tests should leave perturbations + // unset on the python side too. + std::pair forward_av(GGMLRunnerContext* ctx, + LTX2AVModalityArgs vargs, + LTX2AVModalityArgs aargs) { + GGML_ASSERT(has_audio_video && "block lacks audio-video weights"); + + const bool run_vx = (vargs.x != nullptr); + const bool run_ax = (aargs.x != nullptr); + const bool run_a2v = run_vx && run_ax; + const bool run_v2a = run_vx && run_ax; // mirrors python: both modalities present + + ggml_tensor* vx = vargs.x; + ggml_tensor* ax = aargs.x; + + auto v_sst = params["scale_shift_table"]; + auto v_prompt_sst = cross_attention_adaln ? params["prompt_scale_shift_table"] : nullptr; + auto a_sst = params["audio_scale_shift_table"]; + auto a_prompt_sst = cross_attention_adaln ? 
params["audio_prompt_scale_shift_table"] : nullptr; + auto sst_av_v = params["scale_shift_table_a2v_ca_video"]; // [dim, 5] + auto sst_av_a = params["scale_shift_table_a2v_ca_audio"]; // [audio_dim, 5] + + auto v_attn1 = std::dynamic_pointer_cast(blocks["attn1"]); + auto v_attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + auto v_ff = std::dynamic_pointer_cast(blocks["ff"]); + auto a_attn1 = std::dynamic_pointer_cast(blocks["audio_attn1"]); + auto a_attn2 = std::dynamic_pointer_cast(blocks["audio_attn2"]); + auto a_ff = std::dynamic_pointer_cast(blocks["audio_ff"]); + auto a2v_attn = std::dynamic_pointer_cast(blocks["audio_to_video_attn"]); + auto v2a_attn = std::dynamic_pointer_cast(blocks["video_to_audio_attn"]); + + // === Video self-attn + text-CA === + if (run_vx) { + auto t1 = extract_triple(ctx->ggml_ctx, v_sst, vargs.modulation, 0); + auto sh = std::get<0>(t1), sc = std::get<1>(t1), ga = std::get<2>(t1); + auto norm_v = parameterless_rms_norm(ctx->ggml_ctx, vx, norm_eps); + auto scaled = ggml_add(ctx->ggml_ctx, norm_v, ggml_mul(ctx->ggml_ctx, norm_v, sc)); + auto modul = ggml_add(ctx->ggml_ctx, scaled, sh); + auto out = v_attn1->forward(ctx, modul, nullptr, vargs.pe, nullptr); + vx = ggml_add(ctx->ggml_ctx, vx, ggml_mul(ctx->ggml_ctx, out, ga)); + vx = apply_text_ca(ctx, vx, vargs.context, vargs.context_mask, + v_attn2, v_sst, vargs.modulation, vargs.prompt_modulation, v_prompt_sst); + } + + // === Audio self-attn + text-CA === + if (run_ax) { + auto t1 = extract_triple(ctx->ggml_ctx, a_sst, aargs.modulation, 0); + auto sh = std::get<0>(t1), sc = std::get<1>(t1), ga = std::get<2>(t1); + auto norm_a = parameterless_rms_norm(ctx->ggml_ctx, ax, norm_eps); + auto scaled = ggml_add(ctx->ggml_ctx, norm_a, ggml_mul(ctx->ggml_ctx, norm_a, sc)); + auto modul = ggml_add(ctx->ggml_ctx, scaled, sh); + auto out = a_attn1->forward(ctx, modul, nullptr, aargs.pe, nullptr); + ax = ggml_add(ctx->ggml_ctx, ax, ggml_mul(ctx->ggml_ctx, out, ga)); + ax = apply_text_ca(ctx, 
ax, aargs.context, aargs.context_mask, + a_attn2, a_sst, aargs.modulation, aargs.prompt_modulation, a_prompt_sst); + } + + // === Audio-Video cross-attention === + if (run_a2v || run_v2a) { + auto vx_norm3 = run_vx ? parameterless_rms_norm(ctx->ggml_ctx, vx, norm_eps) : nullptr; + auto ax_norm3 = run_ax ? parameterless_rms_norm(ctx->ggml_ctx, ax, norm_eps) : nullptr; + + if (run_a2v) { + // Q from video, K/V from audio. + auto v_av = extract_av_modulation(ctx->ggml_ctx, sst_av_v, + vargs.cross_scale_shift_modulation, + vargs.cross_gate_modulation, /*start=*/0); + auto v_scale = std::get<0>(v_av), v_shift = std::get<1>(v_av), gate_a2v = std::get<2>(v_av); + auto vx_scaled = ggml_add(ctx->ggml_ctx, vx_norm3, ggml_mul(ctx->ggml_ctx, vx_norm3, v_scale)); + vx_scaled = ggml_add(ctx->ggml_ctx, vx_scaled, v_shift); + + auto a_av = extract_av_modulation(ctx->ggml_ctx, sst_av_a, + aargs.cross_scale_shift_modulation, + aargs.cross_gate_modulation, /*start=*/0); + auto a_scale = std::get<0>(a_av), a_shift = std::get<1>(a_av); + auto ax_scaled = ggml_add(ctx->ggml_ctx, ax_norm3, ggml_mul(ctx->ggml_ctx, ax_norm3, a_scale)); + ax_scaled = ggml_add(ctx->ggml_ctx, ax_scaled, a_shift); + + // Cross-modal RoPE: Q (video) uses video.cross_pe, K (audio) uses audio.cross_pe. + auto out = a2v_attn->forward(ctx, vx_scaled, ax_scaled, + vargs.cross_pe, /*mask=*/nullptr, + /*k_pe=*/aargs.cross_pe); + vx = ggml_add(ctx->ggml_ctx, vx, ggml_mul(ctx->ggml_ctx, out, gate_a2v)); + } + + if (run_v2a) { + // Q from audio, K/V from video. 
+ auto a_av = extract_av_modulation(ctx->ggml_ctx, sst_av_a, + aargs.cross_scale_shift_modulation, + aargs.cross_gate_modulation, /*start=*/2); + auto a_scale = std::get<0>(a_av), a_shift = std::get<1>(a_av), gate_v2a = std::get<2>(a_av); + auto ax_scaled = ggml_add(ctx->ggml_ctx, ax_norm3, ggml_mul(ctx->ggml_ctx, ax_norm3, a_scale)); + ax_scaled = ggml_add(ctx->ggml_ctx, ax_scaled, a_shift); + + auto v_av = extract_av_modulation(ctx->ggml_ctx, sst_av_v, + vargs.cross_scale_shift_modulation, + vargs.cross_gate_modulation, /*start=*/2); + auto v_scale = std::get<0>(v_av), v_shift = std::get<1>(v_av); + auto vx_scaled = ggml_add(ctx->ggml_ctx, vx_norm3, ggml_mul(ctx->ggml_ctx, vx_norm3, v_scale)); + vx_scaled = ggml_add(ctx->ggml_ctx, vx_scaled, v_shift); + + auto out = v2a_attn->forward(ctx, ax_scaled, vx_scaled, + aargs.cross_pe, /*mask=*/nullptr, + /*k_pe=*/vargs.cross_pe); + ax = ggml_add(ctx->ggml_ctx, ax, ggml_mul(ctx->ggml_ctx, out, gate_v2a)); + } + } + + // === Video FF === + if (run_vx) { + auto t = extract_triple(ctx->ggml_ctx, v_sst, vargs.modulation, 3); + auto sh = std::get<0>(t), sc = std::get<1>(t), ga = std::get<2>(t); + auto norm = parameterless_rms_norm(ctx->ggml_ctx, vx, norm_eps); + auto scaled = ggml_add(ctx->ggml_ctx, norm, ggml_mul(ctx->ggml_ctx, norm, sc)); + auto modul = ggml_add(ctx->ggml_ctx, scaled, sh); + auto out = v_ff->forward(ctx, modul); + vx = ggml_add(ctx->ggml_ctx, vx, ggml_mul(ctx->ggml_ctx, out, ga)); + } + + // === Audio FF === + if (run_ax) { + auto t = extract_triple(ctx->ggml_ctx, a_sst, aargs.modulation, 3); + auto sh = std::get<0>(t), sc = std::get<1>(t), ga = std::get<2>(t); + auto norm = parameterless_rms_norm(ctx->ggml_ctx, ax, norm_eps); + auto scaled = ggml_add(ctx->ggml_ctx, norm, ggml_mul(ctx->ggml_ctx, norm, sc)); + auto modul = ggml_add(ctx->ggml_ctx, scaled, sh); + auto out = a_ff->forward(ctx, modul); + ax = ggml_add(ctx->ggml_ctx, ax, ggml_mul(ctx->ggml_ctx, out, ga)); + } + + return {vx, ax}; + } + }; + + 
struct LTXParams { + int64_t in_channels = 128; + int64_t out_channels = 128; + int64_t inner_dim = 4096; + int num_heads = 32; + int head_dim = 128; + int num_layers = 48; + int64_t cross_attention_dim = 4096; + bool cross_attention_adaln = false; + bool apply_gated_attention = false; + float norm_eps = 1e-6f; + float positional_embedding_theta = 10000.f; + std::vector positional_embedding_max_pos = {20, 2048, 2048}; + float timestep_scale_multiplier = 1000.f; + bool use_middle_indices_grid = true; + RopeType rope_type = RopeType::SPLIT; // real LTX-2.3 default + // Optional caption_projection sitting on the DiT side (V1 / 19B); absent for + // tiny parity tests that feed context in DiT inner_dim already. When enabled, + // `caption_channels` is the input dim (connector output) and `caption_hidden` + // / `caption_out` follow the PixArtAlphaTextProjection defaults. + bool has_caption_projection = false; + int64_t caption_channels = 0; + int64_t caption_hidden = 0; + int64_t caption_out = 0; + + // ---- Audio-video extension (model_type=AudioVideo). ---- + // Set `has_audio_video=true` to enable the audio side end-to-end: + // audio_patchify_proj, audio_adaln_single, audio_caption_projection, + // audio_norm_out / audio_scale_shift_table / audio_proj_out, plus the + // four cross-modal AdaLN modules (av_ca_*). Each transformer block + // also gets its audio-side weights and cross-modal a2v / v2a tables. + bool has_audio_video = false; + int64_t audio_in_channels = 128; + int64_t audio_out_channels = 128; + int64_t audio_inner_dim = 2048; // 32 × 64 for 22B + int audio_num_heads = 32; + int audio_head_dim = 64; + int64_t audio_cross_attention_dim = 2048; + std::vector audio_positional_embedding_max_pos = {20}; + float av_ca_timestep_scale_multiplier = 1.f; + // Audio-side caption projection (rare — most checkpoints don't carry it). 
+ bool has_audio_caption_projection = false; + int64_t audio_caption_channels = 0; + int64_t audio_caption_hidden = 0; + int64_t audio_caption_out = 0; + }; + + struct LTXModel : public GGMLBlock { + LTXParams p; + + protected: + void init_params(ggml_context* ctx, const String2TensorStorage&, const std::string prefix = "") override { + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, p.inner_dim, 2); + if (p.has_audio_video) { + params["audio_scale_shift_table"] = + ggml_new_tensor_2d(ctx, GGML_TYPE_F32, p.audio_inner_dim, 2); + } + } + + public: + LTXModel() = default; + LTXModel(LTXParams p) : p(p) { + blocks["patchify_proj"] = std::make_shared(p.in_channels, p.inner_dim, true); + int coeff = p.cross_attention_adaln ? ADALN_WITH_CA : ADALN_BASE; + blocks["adaln_single"] = std::make_shared(p.inner_dim, coeff); + blocks["proj_out"] = std::make_shared(p.inner_dim, p.out_channels, true); + + // V2: a second AdaLayerNormSingle that generates modulation for the + // context path inside cross-attention. Python: + // `prompt_adaln_single = AdaLayerNormSingle(inner_dim, embedding_coefficient=2)`. + if (p.cross_attention_adaln) { + blocks["prompt_adaln_single"] = std::make_shared(p.inner_dim, 2); + } + + // Audio-video components (model_type=AudioVideo). + if (p.has_audio_video) { + blocks["audio_patchify_proj"] = + std::make_shared(p.audio_in_channels, p.audio_inner_dim, true); + blocks["audio_adaln_single"] = + std::make_shared(p.audio_inner_dim, coeff); + blocks["audio_proj_out"] = + std::make_shared(p.audio_inner_dim, p.audio_out_channels, true); + if (p.cross_attention_adaln) { + blocks["audio_prompt_adaln_single"] = + std::make_shared(p.audio_inner_dim, 2); + } + if (p.has_audio_caption_projection) { + blocks["audio_caption_projection"] = std::make_shared( + p.audio_caption_channels, p.audio_caption_hidden, p.audio_caption_out); + } + // Cross-modal AdaLN modules. 
Coefficients per python LTXModel._init_audio_video: + // av_ca_video_scale_shift: 4 for the (a2v_scale, a2v_shift, v2a_scale, v2a_shift) row pack. + // av_ca_audio_scale_shift: same shape on audio side. + // av_ca_a2v_gate: 1 (single gate for video Q in a2v). + // av_ca_v2a_gate: 1 (single gate for audio Q in v2a). + blocks["av_ca_video_scale_shift_adaln_single"] = + std::make_shared(p.inner_dim, 4); + blocks["av_ca_audio_scale_shift_adaln_single"] = + std::make_shared(p.audio_inner_dim, 4); + blocks["av_ca_a2v_gate_adaln_single"] = + std::make_shared(p.inner_dim, 1); + blocks["av_ca_v2a_gate_adaln_single"] = + std::make_shared(p.audio_inner_dim, 1); + } + + for (int i = 0; i < p.num_layers; ++i) { + if (p.has_audio_video) { + blocks["transformer_blocks." + std::to_string(i)] = + std::make_shared(p.inner_dim, p.num_heads, p.head_dim, + p.cross_attention_dim, + p.cross_attention_adaln, + p.apply_gated_attention, + p.norm_eps, + p.rope_type, + p.audio_inner_dim, + p.audio_num_heads, + p.audio_head_dim, + p.audio_cross_attention_dim); + } else { + blocks["transformer_blocks." 
+ std::to_string(i)] = + std::make_shared(p.inner_dim, p.num_heads, p.head_dim, + p.cross_attention_dim, + p.cross_attention_adaln, + p.apply_gated_attention, + p.norm_eps, + p.rope_type); + } + } + + if (p.has_caption_projection) { + blocks["caption_projection"] = std::make_shared( + p.caption_channels, p.caption_hidden, p.caption_out); + } + } + + // latent: ne [in_channels, T*H*W, B] (already patchified by caller) + // timestep: ne [B] + // context: ne [cross_attention_dim, S, B] + // pe: ne [inner_dim, T*H*W, 2] (interleaved cos/sin) + // context_mask: ne [S, T*H*W, 1, B] or nullptr + // Returns: ne [out_channels, T*H*W, B] + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* latent, + ggml_tensor* timestep, + ggml_tensor* context, + ggml_tensor* pe, + ggml_tensor* context_mask = nullptr, + // STG (Spatio-Temporal Guidance): block indices whose video + // self-attention is bypassed during the perturbed pass. + // Empty by default — passing a non-empty set produces a + // weakened prediction used by the guider's stg_scale term. + const std::vector* stg_skip_blocks = nullptr) { + auto patchify_proj = std::dynamic_pointer_cast(blocks["patchify_proj"]); + auto adaln_single = std::dynamic_pointer_cast(blocks["adaln_single"]); + auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); + + // Apply caption_projection (V1 / 19B) to lift context from connector dim + // to DiT inner_dim. Python: TransformerArgs._prepare_context. + if (p.has_caption_projection && context != nullptr) { + auto caption_proj = std::dynamic_pointer_cast(blocks["caption_projection"]); + context = caption_proj->forward(ctx, context); + } + + auto x = patchify_proj->forward(ctx, latent); // [inner_dim, T*H*W, B] + + // Caller must feed the already-scaled timestep (σ * 1000). The LTX2 denoiser's sigma_to_t + // is the single source of truth for that scaling — see LTXParams::timestep_scale_multiplier + // which is kept as documentation/config only, not applied here. 
+ auto adaln_res = adaln_single->forward(ctx, timestep); + auto modulation = adaln_res.first; // [inner_dim, coeff, B] (coeff = 6 or 9) + auto embedded_t = adaln_res.second; // [inner_dim, B] + + // V2: prompt_adaln_single takes the same σ (raw timestep before AdaLN-scaling) + // and emits a [inner_dim, 2, B] modulation that's shared across all blocks' + // cross-attention kv path. In Python video_args_preprocessor passes + // `modality.sigma`; for our single-prompt inference sigma == timestep. We reuse + // the same timestep tensor here. + ggml_tensor* prompt_modulation = nullptr; + if (p.cross_attention_adaln) { + auto prompt_adaln = std::dynamic_pointer_cast(blocks["prompt_adaln_single"]); + auto prompt_res = prompt_adaln->forward(ctx, timestep); + prompt_modulation = prompt_res.first; // [inner_dim, 2, B] + } + + for (int i = 0; i < p.num_layers; ++i) { + auto block = std::dynamic_pointer_cast( + blocks["transformer_blocks." + std::to_string(i)]); + bool skip_self_attn = false; + if (stg_skip_blocks != nullptr) { + for (int b : *stg_skip_blocks) { + if (b == i) { skip_self_attn = true; break; } + } + } + x = block->forward(ctx, x, context, modulation, pe, prompt_modulation, + context_mask, skip_self_attn); + } + + // Output modulation: python has `sst[None,None] + embedded[:,:,None]` giving (B, 1, 2, dim). + // In ggml ne that's [dim, 2, 1, B]. For B>1 we'd need to broadcast sst over B explicitly; + // current parity test uses B=1 so we pick the direct add path here and rely on ggml's + // ggml_can_repeat(b, a) — `a` must be >= `b` in every dim so we put sst first. 
+ // sst ne: [inner_dim, 2, 1, 1] + // embedded: [inner_dim, 1, 1, B] (after reshape_4d from [inner_dim, B]) + // sum: [inner_dim, 2, 1, B] (provided B == 1; see TODO for B>1) + int64_t B = x->ne[2]; + GGML_ASSERT(B == 1 && "LTXModel output modulation currently assumes batch=1"); + auto sst = params["scale_shift_table"]; // ne [inner_dim, 2] + auto emb_view = ggml_reshape_4d(ctx->ggml_ctx, embedded_t, p.inner_dim, 1, 1, B); // ne [inner_dim, 1, 1, B] + auto ss_sum = ggml_add(ctx->ggml_ctx, sst, emb_view); // ne [inner_dim, 2, 1, 1] + auto chunks = ggml_ext_chunk(ctx->ggml_ctx, ss_sum, 2, 1); // 2× ne [inner_dim, 1, 1, 1] + auto shift = ggml_reshape_3d(ctx->ggml_ctx, chunks[0], p.inner_dim, 1, 1); // ne [inner_dim, 1, 1] + auto scale = ggml_reshape_3d(ctx->ggml_ctx, chunks[1], p.inner_dim, 1, 1); // ne [inner_dim, 1, 1] + + x = ggml_ext_layer_norm(ctx->ggml_ctx, x, nullptr, nullptr, p.norm_eps); // param-less LN + + // x ne: [inner_dim, T, 1]; scale/shift ne: [inner_dim, 1, 1] — second arg broadcasts ok. + x = ggml_add(ctx->ggml_ctx, ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, x, scale)), shift); + x = proj_out->forward(ctx, x); // [out_channels, T*H*W, B] + return x; + } + + // Helper for the output head: param-less LayerNorm + (sst[None, None] + + // embedded[..., None]) modulation + linear projection. Mirrors python + // LTXModel._process_output. B==1 only (matches existing forward()). 
+ ggml_tensor* process_output_head(GGMLRunnerContext* ctx, + ggml_tensor* sst, // [dim, 2] + ggml_tensor* x, // [dim, T, B=1] + ggml_tensor* embedded_t, // [dim, B=1] + std::shared_ptr proj_out, + int64_t dim) { + int64_t B = x->ne[2]; + GGML_ASSERT(B == 1 && "LTXModel output modulation currently assumes batch=1"); + auto emb_view = ggml_reshape_4d(ctx->ggml_ctx, embedded_t, dim, 1, 1, B); + auto ss_sum = ggml_add(ctx->ggml_ctx, sst, emb_view); + auto chunks = ggml_ext_chunk(ctx->ggml_ctx, ss_sum, 2, 1); + auto shift = ggml_reshape_3d(ctx->ggml_ctx, chunks[0], dim, 1, 1); + auto scale = ggml_reshape_3d(ctx->ggml_ctx, chunks[1], dim, 1, 1); + x = ggml_ext_layer_norm(ctx->ggml_ctx, x, nullptr, nullptr, p.norm_eps); + x = ggml_add(ctx->ggml_ctx, ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, x, scale)), shift); + x = proj_out->forward(ctx, x); + return x; + } + + // Audio-video forward — mirrors python LTXModel.forward when both video + // and audio modalities are provided. Returns {video_out, audio_out}, + // each in shape [out_channels, T, B]. 
+ // + // Inputs: + // *_latent: pre-patchify modality latent ([in_channels, T, B]) + // *_t_self: σ·timestep_scale_multiplier — fed to {video,audio}_adaln_single + // *_t_prompt_self: same scaling — fed to *_prompt_adaln_single (only when + // cross_attention_adaln=true; pass nullptr otherwise) + // *_t_cross_ss: cross-modality σ·timestep_scale_multiplier — fed to + // av_ca_{video,audio}_scale_shift_adaln_single + // *_t_cross_gate: cross-modality σ·av_ca_timestep_scale_multiplier — fed + // to av_ca_{a2v,v2a}_gate_adaln_single + // *_context: text encoder output (if has_caption_projection, this is + // the unprojected version; otherwise must already be in inner_dim space) + // *_pe: per-modality positional embeddings ([inner_dim, T, 2]) + // *_cross_pe: per-modality cross-modal positional embeddings sized to + // audio_inner_dim ([audio_inner_dim, T, 2]) + // *_context_mask: optional additive log-bias mask + std::pair forward_av( + GGMLRunnerContext* ctx, + ggml_tensor* v_latent, ggml_tensor* a_latent, + ggml_tensor* v_t_self, ggml_tensor* a_t_self, + ggml_tensor* v_t_prompt_self, ggml_tensor* a_t_prompt_self, + ggml_tensor* v_t_cross_ss, ggml_tensor* a_t_cross_ss, + ggml_tensor* v_t_cross_gate, ggml_tensor* a_t_cross_gate, + ggml_tensor* v_context, ggml_tensor* a_context, + ggml_tensor* v_pe, ggml_tensor* a_pe, + ggml_tensor* v_cross_pe, ggml_tensor* a_cross_pe, + ggml_tensor* v_context_mask = nullptr, + ggml_tensor* a_context_mask = nullptr) { + GGML_ASSERT(p.has_audio_video && "LTXModel was not configured for audio-video"); + + // ---- Video patchify + caption projection ---- + auto v_patchify = std::dynamic_pointer_cast(blocks["patchify_proj"]); + auto v_adaln = std::dynamic_pointer_cast(blocks["adaln_single"]); + auto v_proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); + if (p.has_caption_projection && v_context != nullptr) { + auto cp = std::dynamic_pointer_cast(blocks["caption_projection"]); + v_context = cp->forward(ctx, v_context); + } + auto vx 
= v_patchify->forward(ctx, v_latent); + auto v_adaln_res = v_adaln->forward(ctx, v_t_self); + auto v_modulation = v_adaln_res.first; + auto v_embedded_timestep = v_adaln_res.second; + ggml_tensor* v_prompt_modulation = nullptr; + if (p.cross_attention_adaln) { + GGML_ASSERT(v_t_prompt_self != nullptr); + auto v_prompt_adaln = std::dynamic_pointer_cast(blocks["prompt_adaln_single"]); + v_prompt_modulation = v_prompt_adaln->forward(ctx, v_t_prompt_self).first; + } + + // ---- Audio patchify + caption projection ---- + auto a_patchify = std::dynamic_pointer_cast(blocks["audio_patchify_proj"]); + auto a_adaln = std::dynamic_pointer_cast(blocks["audio_adaln_single"]); + auto a_proj_out = std::dynamic_pointer_cast(blocks["audio_proj_out"]); + if (p.has_audio_caption_projection && a_context != nullptr) { + auto cp = std::dynamic_pointer_cast(blocks["audio_caption_projection"]); + a_context = cp->forward(ctx, a_context); + } + auto ax = a_patchify->forward(ctx, a_latent); + auto a_adaln_res = a_adaln->forward(ctx, a_t_self); + auto a_modulation = a_adaln_res.first; + auto a_embedded_timestep = a_adaln_res.second; + ggml_tensor* a_prompt_modulation = nullptr; + if (p.cross_attention_adaln) { + GGML_ASSERT(a_t_prompt_self != nullptr); + auto a_prompt_adaln = std::dynamic_pointer_cast(blocks["audio_prompt_adaln_single"]); + a_prompt_modulation = a_prompt_adaln->forward(ctx, a_t_prompt_self).first; + } + + // ---- Cross-modal AdaLN modulations (one set per modality) ---- + auto v_cross_ss_adaln = std::dynamic_pointer_cast(blocks["av_ca_video_scale_shift_adaln_single"]); + auto a_cross_ss_adaln = std::dynamic_pointer_cast(blocks["av_ca_audio_scale_shift_adaln_single"]); + auto v_cross_gate_adaln = std::dynamic_pointer_cast(blocks["av_ca_a2v_gate_adaln_single"]); + auto a_cross_gate_adaln = std::dynamic_pointer_cast(blocks["av_ca_v2a_gate_adaln_single"]); + auto v_cross_ss_mod = v_cross_ss_adaln->forward(ctx, v_t_cross_ss).first; // [inner_dim, 4, B] + auto a_cross_ss_mod = 
a_cross_ss_adaln->forward(ctx, a_t_cross_ss).first; // [audio_inner_dim, 4, B] + auto v_cross_gate_mod = v_cross_gate_adaln->forward(ctx, v_t_cross_gate).first; // [inner_dim, 1, B] + auto a_cross_gate_mod = a_cross_gate_adaln->forward(ctx, a_t_cross_gate).first; // [audio_inner_dim, 1, B] + + // ---- Run all transformer blocks ---- + for (int i = 0; i < p.num_layers; ++i) { + auto block = std::dynamic_pointer_cast( + blocks["transformer_blocks." + std::to_string(i)]); + LTX2AVModalityArgs vargs; + vargs.x = vx; vargs.context = v_context; vargs.modulation = v_modulation; + vargs.pe = v_pe; vargs.cross_pe = v_cross_pe; + vargs.prompt_modulation = v_prompt_modulation; + vargs.context_mask = v_context_mask; + vargs.cross_scale_shift_modulation = v_cross_ss_mod; + vargs.cross_gate_modulation = v_cross_gate_mod; + + LTX2AVModalityArgs aargs; + aargs.x = ax; aargs.context = a_context; aargs.modulation = a_modulation; + aargs.pe = a_pe; aargs.cross_pe = a_cross_pe; + aargs.prompt_modulation = a_prompt_modulation; + aargs.context_mask = a_context_mask; + aargs.cross_scale_shift_modulation = a_cross_ss_mod; + aargs.cross_gate_modulation = a_cross_gate_mod; + + auto outs = block->forward_av(ctx, vargs, aargs); + vx = outs.first; + ax = outs.second; + } + + // ---- Output heads ---- + auto v_sst = params["scale_shift_table"]; + auto a_sst = params["audio_scale_shift_table"]; + auto v_out = process_output_head(ctx, v_sst, vx, v_embedded_timestep, v_proj_out, p.inner_dim); + auto a_out = process_output_head(ctx, a_sst, ax, a_embedded_timestep, a_proj_out, p.audio_inner_dim); + return {v_out, a_out}; + } + }; + + struct LTXRunner : public GGMLRunner { + public: + LTXParams ltx_params; + LTXModel ltx; + std::vector pe_vec; + SDVersion version; + // fps used for temporal RoPE normalisation — see LTXRope::gen_video_positions. + // Defaults to 24 (LTX-2's canonical output fps); callers can override before compute(). 
+ float fps = 24.0f; + // VAE spatiotemporal compression factors (time, height, width) applied to latent + // coordinates to reconstruct the pixel-space positions used for RoPE. Defaults match + // the LTX-2 22B VAE: 8× temporal, 32× spatial. The parity tests feed the Python model + // simplified positions (f/fps, h, w) — set scale_factors={1,1,1} and causal_fix=false + // in that path to keep parity assertions valid. + std::vector scale_factors = {8, 32, 32}; + bool causal_fix = true; + + void set_fps(float new_fps) { fps = new_fps; } + void set_scale_factors(int time, int height, int width) { + scale_factors = {time, height, width}; + } + void set_causal_fix(bool enable) { causal_fix = enable; } + + // params_override forces the given LTXParams instead of auto-detecting from the tensor map. + // Useful for parity tests and for cases where metadata pins the head_dim / num_heads to + // values that can't be inferred from weight shapes alone (q_norm etc. are inner_dim-wide). + LTXRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "model.diffusion_model", + SDVersion version = VERSION_LTX2, + const LTXParams* params_override = nullptr) + : GGMLRunner(backend, offload_params_to_cpu), version(version) { + if (params_override != nullptr) { + ltx_params = *params_override; + } else { + detect_params(tensor_storage_map, prefix); + } + ltx = LTXModel(ltx_params); + ltx.init(params_ctx, tensor_storage_map, prefix); + } + + void detect_params(const String2TensorStorage& tensor_storage_map, const std::string& prefix) { + std::string pre = prefix.empty() ? 
"" : prefix + "."; + + auto patchify_it = tensor_storage_map.find(pre + "patchify_proj.weight"); + if (patchify_it != tensor_storage_map.end()) { + const auto& ts = patchify_it->second; + if (ts.n_dims >= 2) { + ltx_params.in_channels = ts.ne[0]; + ltx_params.inner_dim = ts.ne[1]; + } + } + + auto proj_out_it = tensor_storage_map.find(pre + "proj_out.weight"); + if (proj_out_it != tensor_storage_map.end()) { + const auto& ts = proj_out_it->second; + if (ts.n_dims >= 2) { + ltx_params.out_channels = ts.ne[1]; + } + } + + // Infer num_layers from highest transformer_blocks index. + int max_layer = -1; + std::string block_prefix = pre + "transformer_blocks."; + for (auto& pair : tensor_storage_map) { + const std::string& name = pair.first; + if (name.rfind(block_prefix, 0) != 0) { + continue; + } + size_t start = block_prefix.size(); + size_t end = name.find('.', start); + if (end == std::string::npos) { + continue; + } + try { + int idx = std::stoi(name.substr(start, end - start)); + max_layer = std::max(max_layer, idx); + } catch (...) { + } + } + if (max_layer >= 0) { + ltx_params.num_layers = max_layer + 1; + } + + // Detect cross_attention_adaln from the size of scale_shift_table (9 if CA-AdaLN, 6 otherwise). + auto sst_it = tensor_storage_map.find(pre + "transformer_blocks.0.scale_shift_table"); + if (sst_it != tensor_storage_map.end()) { + const auto& ts = sst_it->second; + if (ts.n_dims >= 2 && ts.ne[1] == ADALN_WITH_CA) { + ltx_params.cross_attention_adaln = true; + } + } + + // Infer head_dim × num_heads from attn1.to_q.weight shape. + auto q_it = tensor_storage_map.find(pre + "transformer_blocks.0.attn1.to_q.weight"); + if (q_it != tensor_storage_map.end()) { + const auto& ts = q_it->second; + if (ts.n_dims >= 2) { + ltx_params.inner_dim = ts.ne[1]; + } + } + // head_dim is a fixed LTX-2 hyperparam (128) unless a config tensor overrides. 
+ ltx_params.head_dim = 128; + ltx_params.num_heads = static_cast(ltx_params.inner_dim / ltx_params.head_dim); + + // Infer cross_attention_dim from attn2.to_k weight shape. + auto k_it = tensor_storage_map.find(pre + "transformer_blocks.0.attn2.to_k.weight"); + if (k_it != tensor_storage_map.end()) { + const auto& ts = k_it->second; + if (ts.n_dims >= 2) { + ltx_params.cross_attention_dim = ts.ne[0]; + } + } + + // Detect gated attention from presence of to_gate_logits. + auto gate_it = tensor_storage_map.find(pre + "transformer_blocks.0.attn1.to_gate_logits.weight"); + if (gate_it != tensor_storage_map.end()) { + ltx_params.apply_gated_attention = true; + } + + // Detect optional caption_projection (V1 / 19B). + // linear_1 weight shape [in_features, hidden_size]; linear_2 shape [hidden_size, out_features]. + // (ggml ne[0] = innermost dim = PyTorch's in_features / hidden_size.) + auto cap1_it = tensor_storage_map.find(pre + "caption_projection.linear_1.weight"); + auto cap2_it = tensor_storage_map.find(pre + "caption_projection.linear_2.weight"); + if (cap1_it != tensor_storage_map.end() && cap2_it != tensor_storage_map.end()) { + const auto& l1 = cap1_it->second; + const auto& l2 = cap2_it->second; + if (l1.n_dims >= 2 && l2.n_dims >= 2) { + ltx_params.has_caption_projection = true; + ltx_params.caption_channels = l1.ne[0]; + ltx_params.caption_hidden = l1.ne[1]; + ltx_params.caption_out = l2.ne[1]; + } + } + } + + std::string get_desc() override { + return "ltx2"; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + ltx.get_param_tensors(tensors, prefix); + } + + // Build the diffusion graph. + // x_tensor layout (ggml ne order): [W, H, T, in_channels] — follows the Wan / video convention with implicit batch N=1. 
+ // timesteps: ne [N] + // context: ne [cross_attention_dim, S, N] + // context_mask: empty (not yet wired through) + ggml_cgraph* build_graph(const sd::Tensor& x_tensor, + const sd::Tensor& timesteps_tensor, + const sd::Tensor& context_tensor, + const sd::Tensor& context_mask_tensor, + const std::vector* stg_skip_blocks = nullptr) { + ggml_cgraph* gf = new_graph_custom(LTX_GRAPH_SIZE); + + ggml_tensor* x = make_input(x_tensor); + ggml_tensor* timesteps = make_input(timesteps_tensor); + ggml_tensor* context = make_input(context_tensor); + ggml_tensor* ctx_mask = context_mask_tensor.empty() ? nullptr : make_input(context_mask_tensor); + + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t T = x->ne[2]; + int64_t C = x->ne[3]; + + LOG_DEBUG("LTX build_graph: x=[%lld,%lld,%lld,%lld] timesteps=[%lld] context=[%lld,%lld,%lld] inner_dim=%lld cross_attn_dim=%lld has_cap_proj=%d ca_adaln=%d gated=%d", + (long long)x->ne[0], (long long)x->ne[1], (long long)x->ne[2], (long long)x->ne[3], + (long long)timesteps->ne[0], + (long long)context->ne[0], (long long)context->ne[1], (long long)context->ne[2], + (long long)ltx_params.inner_dim, (long long)ltx_params.cross_attention_dim, + ltx_params.has_caption_projection ? 1 : 0, + ltx_params.cross_attention_adaln ? 1 : 0, + ltx_params.apply_gated_attention ? 1 : 0); + + // Flatten spatiotemporal dims into a sequence and move channels to ne[0]. 
+ auto latent = ggml_reshape_3d(compute_ctx, x, W * H * T, C, 1); // [W*H*T, C, 1] + latent = ggml_cont(compute_ctx, ggml_permute(compute_ctx, latent, 1, 0, 2, 3)); // [C, W*H*T, 1] + + auto positions = LTXRope::gen_video_positions(static_cast(T), static_cast(H), static_cast(W), + ltx_params.use_middle_indices_grid, fps, + scale_factors, causal_fix); + ggml_tensor* pe = nullptr; + if (ltx_params.rope_type == RopeType::SPLIT) { + pe_vec = LTXRope::precompute_freqs_cis_split(positions, + static_cast(ltx_params.inner_dim), + ltx_params.num_heads, + ltx_params.positional_embedding_theta, + ltx_params.positional_embedding_max_pos); + // Split layout ne: [head_dim/2, num_heads, T*H*W, 2]. + int64_t half = ltx_params.inner_dim / 2; + int64_t per_head_half = half / ltx_params.num_heads; + pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, + per_head_half, ltx_params.num_heads, T * H * W, 2); + } else { + pe_vec = LTXRope::precompute_freqs_cis_interleaved(positions, + static_cast(ltx_params.inner_dim), + ltx_params.positional_embedding_theta, + ltx_params.positional_embedding_max_pos); + pe = ggml_new_tensor_3d(compute_ctx, GGML_TYPE_F32, ltx_params.inner_dim, T * H * W, 2); + } + set_backend_tensor_data(pe, pe_vec.data()); + + auto runner_ctx = get_context(); + ggml_tensor* out = ltx.forward(&runner_ctx, latent, timesteps, context, pe, ctx_mask, + stg_skip_blocks); + + // out: [out_channels, T*H*W, 1] → [W, H, T, out_channels] to match Wan-style output. 
+ out = ggml_cont(compute_ctx, ggml_permute(compute_ctx, out, 1, 0, 2, 3)); // [T*H*W, out_channels, 1] + out = ggml_reshape_4d(compute_ctx, out, W, H, T, ltx_params.out_channels); + + ggml_build_forward_expand(gf, out); + return gf; + } + + sd::Tensor compute(int n_threads, + const sd::Tensor& x, + const sd::Tensor& timesteps, + const sd::Tensor& context, + const sd::Tensor& context_mask, + const std::vector* stg_skip_blocks = nullptr) { + auto get_graph = [&]() -> ggml_cgraph* { + return build_graph(x, timesteps, context, context_mask, stg_skip_blocks); + }; + return take_or_empty(GGMLRunner::compute(get_graph, n_threads, true)); + } + }; + +} // namespace LTX + +#endif // __LTX_HPP__ diff --git a/src/ltx_connector.hpp b/src/ltx_connector.hpp new file mode 100644 index 000000000..c27276359 --- /dev/null +++ b/src/ltx_connector.hpp @@ -0,0 +1,632 @@ +#ifndef __LTX_CONNECTOR_HPP__ +#define __LTX_CONNECTOR_HPP__ + +#include +#include +#include +#include + +#include "ggml_extend.hpp" +#include "ltx.hpp" +#include "ltx_rope.hpp" +#include "model.h" + +// 1D position generator for the connector's RoPE (n_pos_dims=1, max_pos=[1], +// positions[t] = t). Lives here so it sits next to its only caller, but stays +// in the LTXRope namespace. +namespace LTXRope { + __STATIC_INLINE__ std::vector> gen_1d_positions(int T) { + std::vector> pos(1, std::vector(T, 0.f)); + for (int t = 0; t < T; ++t) pos[0][t] = static_cast(t); + return pos; + } +} // namespace LTXRope + +// LTX-2 text connector (Phase 9.1, V1 / 19B). 
+// +// Python reference: +// ltx_core/text_encoders/gemma/feature_extractor.py (FeatureExtractorV1) +// ltx_core/text_encoders/gemma/embeddings_connector.py (Embeddings1DConnector) +// ltx_core/model/transformer/text_projection.py (PixArtAlphaTextProjection) +// +// Pipeline (Gemma 49-layer stack → DiT cross-attention context): +// stacked[B, T, D, L] → feature_extractor_normalize() (CPU, per-(B,L) masked +// mean/range → normed[B, T, D*L]) +// normed[B, T, D*L] → FeatureExtractorV1::forward (aggregate_embed Linear) +// → video_features[B, T, inner_dim] +// video_features → Embeddings1DConnector::forward (2× BasicTransformerBlock1D +// + final rms_norm) → [B, T, inner_dim] +// connector_out → PixArtAlphaTextProjection::forward (linear, gelu_tanh, +// linear) → [B, T, caption_out_dim] (= DiT inner_dim) + +namespace LTXConnector { + + // Compute FeatureExtractorV1's _norm_and_concat_padded_batch on the CPU. + // Python reference: _norm_and_concat_padded_batch in feature_extractor.py. + // + // Input: + // stacked: [B*T*D*L] contiguous, logical shape [B, T, D, L] + // seq_lengths: [B] — valid (non-pad) token count per batch + // padding_side: "left" or "right" + // Output: + // normed: [B*T*(D*L)] contiguous, logical shape [B, T, D*L] + // + // Padded positions (outside [0, seq_len) for "right", outside [T - seq_len, T) for "left") + // are zero'd after the normalization. + __STATIC_INLINE__ void feature_extractor_normalize(const float* stacked, + const int* seq_lengths, + float* normed, + int B, int T, int D, int L, + const std::string& padding_side = "left", + float eps = 1e-6f) { + const float FINF = std::numeric_limits::infinity(); + const float NINF = -FINF; + const bool is_left = (padding_side == "left"); + + for (int b = 0; b < B; ++b) { + int seq_len = seq_lengths[b]; + int valid_start = is_left ? (T - seq_len) : 0; + int valid_end = is_left ? 
T : seq_len; + + for (int l = 0; l < L; ++l) { + // Compute per-(b,l) masked mean, min, max over (t, d) where mask == 1. + double sum = 0.0; + float vmin = FINF; + float vmax = NINF; + for (int t = valid_start; t < valid_end; ++t) { + for (int d = 0; d < D; ++d) { + // Python layout: encoded[b, t, d, l] + // Flat index with ne [L, D, T, B] order would be ((b*T + t)*D + d)*L + l. + int64_t idx = ((static_cast(b) * T + t) * D + d) * L + l; + float v = stacked[idx]; + sum += v; + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + } + double denom = static_cast(seq_len) * D; + float mean = static_cast(sum / (denom + eps)); + float range = vmax - vmin; + float inv = 8.0f / (range + eps); + + // Apply normalization over all T positions; zero out padded ones. + for (int t = 0; t < T; ++t) { + bool in_valid = (t >= valid_start && t < valid_end); + for (int d = 0; d < D; ++d) { + int64_t src_idx = ((static_cast(b) * T + t) * D + d) * L + l; + // normed layout: [B, T, D*L] with flat index (b*T + t)*(D*L) + (d*L + l). + int64_t dst_idx = (static_cast(b) * T + t) * (D * L) + (d * L + l); + if (in_valid) { + normed[dst_idx] = (stacked[src_idx] - mean) * inv; + } else { + normed[dst_idx] = 0.0f; + } + } + } + } + } + } + + // Per-token RMSNorm used by FeatureExtractorV2 (22B / V2 text path). Mirrors + // norm_and_concat_per_token_rms in Python feature_extractor.py. + // + // Input layout (ggml ne order, matches llm->compute_all_hidden_states): + // stacked[l + L*(d + D*(t + T*b))] — logical shape [B, T, D, L] + // + // Output layout (ggml ne order): + // normed[k + (D*L)*(t + T*b)] — logical shape [B, T, D*L] with k = d*L + l + // + // Per-(B, T, L) variance is computed over D; every entry is scaled by the + // corresponding rsqrt(var + eps). Padded positions (per `attention_mask`) get + // zeroed out post-reshape, matching Python's `torch.where(mask_3d, normed, 0)`. 
+ // + // The result is NOT yet rescaled by sqrt(target/source) — that's applied as a + // `ggml_scale` in the graph immediately before the aggregate_embed Linear so + // video and audio branches (with different target dims) can share this buffer. + __STATIC_INLINE__ void feature_extractor_normalize_v2(const float* stacked, + const int* seq_lengths, + float* normed, + int B, int T, int D, int L, + const std::string& padding_side = "left", + float eps = 1e-6f) { + const bool is_left = (padding_side == "left"); + for (int b = 0; b < B; ++b) { + int seq_len = seq_lengths[b]; + int valid_start = is_left ? (T - seq_len) : 0; + int valid_end = is_left ? T : seq_len; + + for (int t = 0; t < T; ++t) { + bool in_valid = (t >= valid_start && t < valid_end); + // Per-layer rsqrt factor for the (b, t, *, l) row. + for (int l = 0; l < L; ++l) { + double sum_sq = 0.0; + for (int d = 0; d < D; ++d) { + int64_t idx = ((static_cast(b) * T + t) * D + d) * L + l; + double v = stacked[idx]; + sum_sq += v * v; + } + double variance = sum_sq / static_cast(D); + float rsq = static_cast(1.0 / std::sqrt(variance + eps)); + + for (int d = 0; d < D; ++d) { + int64_t src_idx = ((static_cast(b) * T + t) * D + d) * L + l; + int64_t dst_idx = (static_cast(b) * T + t) * (D * L) + (d * L + l); + if (in_valid) { + normed[dst_idx] = stacked[src_idx] * rsq; + } else { + normed[dst_idx] = 0.0f; + } + } + } + } + } + } + + // FeatureExtractorV1 block — just wraps the aggregate_embed Linear + // (feature_extractor.aggregate_embed.weight). + // + // The CPU-side normalization lives in feature_extractor_normalize(); this block + // expects an already-normalized [B, T, D*L] tensor as input. 
+ struct FeatureExtractorV1 : public GGMLBlock { + protected: + int64_t flat_dim; + int64_t inner_dim; + + public: + FeatureExtractorV1() = default; + FeatureExtractorV1(int64_t flat_dim, int64_t inner_dim) + : flat_dim(flat_dim), inner_dim(inner_dim) { + // Python: aggregate_embed = Linear(flat_dim, inner_dim, bias=False). + // flat_dim is huge (49 × 3840 = 188160 for 22B); F16 matmul accumulator + // can't hold that many sums at full precision. Match V2's force_prec_f32. + blocks["aggregate_embed"] = std::make_shared(flat_dim, inner_dim, /*bias=*/false, + /*force_f32=*/false, /*force_prec_f32=*/true); + } + + // x: ne [flat_dim, T, B] (already normalized via feature_extractor_normalize). + // returns: ne [inner_dim, T, B]. + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto agg = std::dynamic_pointer_cast(blocks["aggregate_embed"]); + return agg->forward(ctx, x); + } + }; + + // FeatureExtractorV2 block — V2 / 22B text path. Two parallel Linears + // (video + optional audio) WITH bias on a per-token RMS-normalized input. + // Python: ltx_core/text_encoders/gemma/feature_extractor.py::FeatureExtractorV2. + // + // The CPU-side normalization lives in feature_extractor_normalize_v2(); this + // block applies the in-graph rescale factor sqrt(target/source_dim) and the + // video_aggregate_embed Linear. Audio path is declared optional — if audio + // weights are absent the block skips it and is video-only. 
+ struct FeatureExtractorV2 : public GGMLBlock { + protected: + int64_t flat_dim; // D * L (Gemma hidden × num_layers) + int64_t source_dim; // Gemma hidden size (D) + int64_t video_out_dim; // DiT inner_dim + int64_t audio_out_dim; // optional; 0 when no audio aggregate_embed + float video_scale; // sqrt(video_out_dim / source_dim) + float audio_scale; // sqrt(audio_out_dim / source_dim) + + public: + FeatureExtractorV2() = default; + FeatureExtractorV2(int64_t flat_dim, int64_t source_dim, + int64_t video_out_dim, + int64_t audio_out_dim = 0) + : flat_dim(flat_dim), source_dim(source_dim), + video_out_dim(video_out_dim), audio_out_dim(audio_out_dim) { + video_scale = std::sqrt(static_cast(video_out_dim) / static_cast(source_dim)); + audio_scale = audio_out_dim > 0 + ? std::sqrt(static_cast(audio_out_dim) / static_cast(source_dim)) + : 0.f; + // Force FP32 matmul precision: flat_dim=188160 sums easily exceed F16's + // mantissa precision and produce direction-rotating drift on later + // tokens. Comfy runs this in BF16 which has full FP32 range; we need + // explicit F32 precision when running on CUDA/F16 backends. + blocks["video_aggregate_embed"] = std::make_shared(flat_dim, video_out_dim, /*bias=*/true, + /*force_f32=*/false, /*force_prec_f32=*/true); + if (audio_out_dim > 0) { + blocks["audio_aggregate_embed"] = std::make_shared(flat_dim, audio_out_dim, /*bias=*/true, + /*force_f32=*/false, /*force_prec_f32=*/true); + } + } + + bool has_audio() const { return audio_out_dim > 0; } + int64_t get_video_out_dim() const { return video_out_dim; } + int64_t get_audio_out_dim() const { return audio_out_dim; } + + // x: ne [flat_dim, T, B] (already per-token RMS-normalized via feature_extractor_normalize_v2). + // Returns video_features ne [video_out_dim, T, B]. Audio branch unused for video-only smoke tests. 
+ ggml_tensor* forward_video(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto agg = std::dynamic_pointer_cast(blocks["video_aggregate_embed"]); + auto scaled = ggml_scale(ctx->ggml_ctx, x, video_scale); + return agg->forward(ctx, scaled); + } + + ggml_tensor* forward_audio(GGMLRunnerContext* ctx, ggml_tensor* x) { + GGML_ASSERT(has_audio() && "FeatureExtractorV2: audio_aggregate_embed not allocated"); + auto agg = std::dynamic_pointer_cast(blocks["audio_aggregate_embed"]); + auto scaled = ggml_scale(ctx->ggml_ctx, x, audio_scale); + return agg->forward(ctx, scaled); + } + }; + + // A single 1D transformer block in the connector. + // Python: _BasicTransformerBlock1D in embeddings_connector.py. + // + // Self-attention only (no cross-attention, no AdaLN). Parameter-free rms_norm + // before attention and before the feed-forward. + struct BasicTransformerBlock1D : public GGMLBlock { + protected: + int64_t dim; + int num_heads; + int head_dim; + bool apply_gated_attention; + float norm_eps; + + public: + BasicTransformerBlock1D() = default; + BasicTransformerBlock1D(int64_t dim, int num_heads, int head_dim, + bool apply_gated_attention = false, + float norm_eps = 1e-6f) + : dim(dim), num_heads(num_heads), head_dim(head_dim), + apply_gated_attention(apply_gated_attention), norm_eps(norm_eps) { + // Self-attention: context_dim = query_dim = dim. The connector's 1D RoPE + // uses INTERLEAVED layout (Python embeddings_connector.py calls + // precompute_freqs_cis with default rope_type=INTERLEAVED); only the DiT + // was switched to SPLIT in LTX-2.3. 
+ blocks["attn1"] = std::make_shared(dim, dim, num_heads, head_dim, + apply_gated_attention, norm_eps, + LTX::RopeType::INTERLEAVED); + blocks["ff"] = std::make_shared(dim, dim); + } + + // hidden_states: ne [dim, T, B] + // pe: ne [dim, T, 2] packed cos/sin (or nullptr) + // mask: additive attention mask (or nullptr) + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* hidden_states, + ggml_tensor* pe, + ggml_tensor* mask = nullptr) { + auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); + auto ff = std::dynamic_pointer_cast(blocks["ff"]); + + // Pre-norm + self-attention + residual. + auto norm1 = LTX::parameterless_rms_norm(ctx->ggml_ctx, hidden_states, norm_eps); + auto a_out = attn1->forward(ctx, norm1, /*context=*/nullptr, pe, mask); + hidden_states = ggml_add(ctx->ggml_ctx, hidden_states, a_out); + + // Pre-norm + feed-forward + residual. + auto norm2 = LTX::parameterless_rms_norm(ctx->ggml_ctx, hidden_states, norm_eps); + auto f_out = ff->forward(ctx, norm2); + hidden_states = ggml_add(ctx->ggml_ctx, hidden_states, f_out); + + return hidden_states; + } + }; + + // Embeddings1DConnector: 2-layer 1D transformer with learnable registers + + // final parameter-free rms_norm. 1D RoPE with max_pos=[1], theta=10000.0. + struct Embeddings1DConnector : public GGMLBlock { + protected: + int num_heads; + int head_dim; + int64_t inner_dim; + int num_layers; + int num_registers; // 0 disables the learnable-registers path. + float theta; + std::vector max_pos; + bool apply_gated_attention; + float norm_eps; + + void init_params(ggml_context* ctx, const String2TensorStorage&, const std::string prefix = "") override { + if (num_registers > 0) { + // Python: learnable_registers = Parameter(rand(num_registers, inner_dim) * 2 - 1) + // ggml ne layout: innermost = inner_dim, so [inner_dim, num_registers]. 
+ params["learnable_registers"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, inner_dim, num_registers); + } + } + + public: + Embeddings1DConnector() = default; + Embeddings1DConnector(int num_heads, int head_dim, int num_layers, + int num_registers = 128, + float theta = 10000.0f, + const std::vector& max_pos = {1}, + bool apply_gated_attention = false, + float norm_eps = 1e-6f) + : num_heads(num_heads), head_dim(head_dim), + inner_dim(static_cast(num_heads) * head_dim), + num_layers(num_layers), num_registers(num_registers), + theta(theta), max_pos(max_pos), + apply_gated_attention(apply_gated_attention), norm_eps(norm_eps) { + for (int i = 0; i < num_layers; ++i) { + blocks["transformer_1d_blocks." + std::to_string(i)] = + std::make_shared(inner_dim, num_heads, head_dim, + apply_gated_attention, norm_eps); + } + } + + int64_t get_inner_dim() const { return inner_dim; } + int get_num_registers() const { return num_registers; } + int get_num_layers() const { return num_layers; } + + ggml_tensor* get_learnable_registers() { + auto it = params.find("learnable_registers"); + return it == params.end() ? nullptr : it->second; + } + + std::shared_ptr get_block(int i) { + return std::dynamic_pointer_cast( + blocks["transformer_1d_blocks." + std::to_string(i)]); + } + + // hidden_states: ne [inner_dim, T, B] + // pe: ne [inner_dim, T, 2] packed cos/sin + // mask: additive attention mask (or nullptr) + // + // NOTE: this currently skips `_replace_padded_with_learnable_registers` — + // callers must guarantee the input is already register-substituted (or no + // padding is present). Handling the register replacement in ggml requires + // boolean indexing/scatter semantics that we defer. + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* hidden_states, + ggml_tensor* pe, + ggml_tensor* mask = nullptr) { + for (int i = 0; i < num_layers; ++i) { + auto block = std::dynamic_pointer_cast( + blocks["transformer_1d_blocks." 
+ std::to_string(i)]); + hidden_states = block->forward(ctx, hidden_states, pe, mask); + } + hidden_states = LTX::parameterless_rms_norm(ctx->ggml_ctx, hidden_states, norm_eps); + return hidden_states; + } + }; + + // Which feature-extractor flavor the runner uses. V1 (19B) has a single + // Linear(flat_dim → inner_dim, bias=False) named `aggregate_embed.weight`, with + // CPU pre-norm via _norm_and_concat_padded_batch. V2 (22B) has two parallel + // Linears with bias (`video_aggregate_embed`, `audio_aggregate_embed`) on a + // per-token RMS-normalized input; we currently wire only the video path. + enum class FeatureExtractorVersion { V1, V2 }; + + // Runner that bundles feature_extractor + Embeddings1DConnector (and optionally + // caption_projection for end-to-end parity testing). Used both by the parity + // test (default ctor args match dump_connector.py) and by LTX2GemmaConditioner + // (which passes real-checkpoint prefixes and sets include_caption_projection=false + // because the DiT owns caption_projection). + // + // Input is the already-normalized [B, T, flat_dim] tensor (see + // feature_extractor_normalize[_v2] for the CPU pre-processing). + struct LTX2ConnectorRunner : public GGMLRunner { + int64_t flat_dim; + int64_t connector_inner_dim; + int num_heads; + int head_dim; + int num_layers; + int num_registers; + int64_t caption_channels; + int64_t caption_hidden; + int64_t caption_out; + float theta; + std::vector max_pos; + bool include_caption_projection; + FeatureExtractorVersion fe_version; + int64_t source_dim; // V2 only: Gemma hidden_size used for rescale + + std::string feat_ext_prefix; + std::string connector_prefix; + std::string caption_proj_prefix; + + FeatureExtractorV1 feature_extractor_v1; + FeatureExtractorV2 feature_extractor_v2; + Embeddings1DConnector connector; + LTX::PixArtAlphaTextProjection caption_projection; + + std::vector pe_vec; + + // probe_stage selects the returned tensor. 
Stages <1 and >2 are shared + // between V1 and V2; 1 and 2 are legacy V1 parity probes (after block 0/1) + // and only work when num_layers >= 2. For V2 (production use), stage 3 + // (final rms_norm) is what the conditioner calls. + // 0 = after feature_extractor (+ graph-side rescale for V2) + // 1 = after connector block 0 + // 2 = after connector block 1 + // 3 = after all blocks + final rms_norm (connector output) + // 4 = after caption_projection (requires include_caption_projection) + int probe_stage = 3; + + // Target sequence length fed into the 1D connector. Python's + // LTXVGemmaTokenizer pads to max_length=1024 so the connector always sees + // 1024 tokens with learnable_registers tiled max_length/num_registers times. + // A value of 0 falls back to num_registers (the old, compact behaviour used + // by the parity dumper). Real inference MUST set this to match the Python + // tokenizer max_length (1024) — see LTX-2 ti2vid pipelines. + int target_seq_len = 0; + void set_target_seq_len(int len) { target_seq_len = len; } + + LTX2ConnectorRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + int64_t flat_dim, + int num_heads, + int head_dim, + int num_layers, + int num_registers, + int64_t caption_channels = 0, + int64_t caption_hidden = 0, + int64_t caption_out = 0, + float theta = 10000.0f, + const std::vector& max_pos = {1}, + const String2TensorStorage& tsm = {}, + bool include_caption_projection = true, + const std::string& feat_ext_prefix = "feature_extractor", + const std::string& connector_prefix = "connector", + const std::string& caption_proj_prefix = "caption_projection", + FeatureExtractorVersion fe_version = FeatureExtractorVersion::V1, + int64_t source_dim = 0, + bool apply_gated_attention = false) + : GGMLRunner(backend, offload_params_to_cpu), + flat_dim(flat_dim), + connector_inner_dim(static_cast(num_heads) * head_dim), + num_heads(num_heads), head_dim(head_dim), num_layers(num_layers), + num_registers(num_registers), + 
caption_channels(caption_channels), + caption_hidden(caption_hidden), + caption_out(caption_out), + theta(theta), max_pos(max_pos), + include_caption_projection(include_caption_projection), + fe_version(fe_version), + source_dim(source_dim), + feat_ext_prefix(feat_ext_prefix), + connector_prefix(connector_prefix), + caption_proj_prefix(caption_proj_prefix) { + if (fe_version == FeatureExtractorVersion::V2) { + GGML_ASSERT(source_dim > 0 && "FeatureExtractorV2 needs Gemma source_dim for the sqrt-rescale"); + feature_extractor_v2 = FeatureExtractorV2(flat_dim, source_dim, connector_inner_dim); + feature_extractor_v2.init(params_ctx, tsm, feat_ext_prefix); + } else { + feature_extractor_v1 = FeatureExtractorV1(flat_dim, connector_inner_dim); + feature_extractor_v1.init(params_ctx, tsm, feat_ext_prefix); + } + connector = Embeddings1DConnector(num_heads, head_dim, num_layers, + num_registers, theta, max_pos, + apply_gated_attention); + connector.init(params_ctx, tsm, connector_prefix); + if (include_caption_projection) { + caption_projection = LTX::PixArtAlphaTextProjection(caption_channels, caption_hidden, caption_out); + caption_projection.init(params_ctx, tsm, caption_proj_prefix); + } + } + + std::string get_desc() override { return "ltx2-connector"; } + + void get_param_tensors(std::map& tensors, + const std::string /*unused*/ = "") { + if (fe_version == FeatureExtractorVersion::V2) { + feature_extractor_v2.get_param_tensors(tensors, feat_ext_prefix); + } else { + feature_extractor_v1.get_param_tensors(tensors, feat_ext_prefix); + } + connector.get_param_tensors(tensors, connector_prefix); + if (include_caption_projection) { + caption_projection.get_param_tensors(tensors, caption_proj_prefix); + } + } + + // Build the full graph. 
probe_stage selects the final returned tensor: + // 0: after feature_extractor (shape [connector_inner_dim, T, B]) + // 1: after connector block 0 (V1 parity probe, legacy) + // 2: after connector block 1 (V1 parity probe, legacy) + // 3: after all connector blocks + final rms_norm + // 4: after caption_projection (needs include_caption_projection=true) + ggml_cgraph* build_graph(const sd::Tensor& normed_in) { + ggml_cgraph* gf = new_graph_custom(LTX::LTX_GRAPH_SIZE); + + ggml_tensor* x = make_input(normed_in); // ne [flat_dim, T, B] + int64_t T = x->ne[1]; + + auto runner_ctx = get_context(); + + // Step 1: feature_extractor → [inner_dim, T, B]. + ggml_tensor* feat = nullptr; + if (fe_version == FeatureExtractorVersion::V2) { + feat = feature_extractor_v2.forward_video(&runner_ctx, x); + } else { + feat = feature_extractor_v1.forward(&runner_ctx, x); + } + + // Step 1.5: Pad to the target length by filling the tail with + // learnable_registers (tiled when target > num_registers). + // + // Python reference: `_replace_padded_with_learnable_registers` in + // ltx_core/text_encoders/gemma/embeddings_connector.py. It: + // 1. tiles learnable_registers by (seq_len / num_registers) so the tiled + // buffer covers the whole sequence (seq_len == tokenizer max_length), + // 2. moves real tokens to [0, T_real), + // 3. fills [T_real, seq_len) with tiled_registers[T_real, seq_len). + // + // The caller (conditioner.hpp) already does step 2 on CPU and passes feat + // as [inner_dim, T_real, B]. We pick the target length in this order of + // preference: (a) explicit target_seq_len (set by the conditioner to + // Gemma's max_length), (b) num_registers (legacy/parity default). + // + // Tiling is implemented with a ggml_repeat into a [inner_dim, target, B] + // destination — cheap on GPU and matches torch.tile semantics for the + // innermost tiling axis. + const int num_registers = connector.get_num_registers(); + int64_t target_len = + target_seq_len > 0 ? 
static_cast(target_seq_len) + : static_cast(num_registers); + if (num_registers > 0 && target_len > 0 && T < target_len) { + GGML_ASSERT(target_len % num_registers == 0 && + "target_seq_len must be a multiple of num_registers " + "(Embeddings1DConnector tiles learnable_registers)."); + auto regs = connector.get_learnable_registers(); // [inner_dim, num_registers] + GGML_ASSERT(regs != nullptr && "learnable_registers not initialized"); + + // Build the tiled registers tensor [inner_dim, target_len] by + // repeating learnable_registers along axis 1. + ggml_tensor* tiled = regs; + if (target_len > num_registers) { + auto repeat_tgt = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, + connector_inner_dim, target_len); + tiled = ggml_repeat(compute_ctx, regs, repeat_tgt); + } + + // Slice rows [T : target_len] along axis 1 to get the padding tail. + auto regs_slice = ggml_ext_slice(compute_ctx, tiled, 1, + static_cast(T), + static_cast(target_len)); // [inner_dim, target-T] + regs_slice = ggml_reshape_3d(compute_ctx, ggml_cont(compute_ctx, regs_slice), + connector_inner_dim, + target_len - T, + 1); + feat = ggml_concat(compute_ctx, feat, regs_slice, 1); // [inner_dim, target, B] + T = target_len; + } + + // Build only the subgraph up to the selected probe stage. The final + // named result is the LAST node added (GGMLRunner::get_compute_graph + // picks `ggml_graph_node(gf, -1)`). + ggml_tensor* out = feat; + if (probe_stage >= 1) { + // Precompute 1D RoPE for connector. + auto positions = LTXRope::gen_1d_positions(static_cast(T)); + pe_vec = LTXRope::precompute_freqs_cis_interleaved(positions, + static_cast(connector_inner_dim), + theta, max_pos); + auto pe = ggml_new_tensor_3d(compute_ctx, GGML_TYPE_F32, connector_inner_dim, T, 2); + set_backend_tensor_data(pe, pe_vec.data()); + + // Stages 1, 2: legacy V1 parity probes — stop after block 0/1. + // Stages 3+: production path — run all blocks and the final rms_norm. 
+ if (probe_stage == 1 || probe_stage == 2) { + int blocks_to_run = probe_stage; // 1 → block 0 only; 2 → blocks 0 and 1 + for (int i = 0; i < blocks_to_run && i < num_layers; ++i) { + out = connector.get_block(i)->forward(&runner_ctx, out, pe, nullptr); + } + } else { + for (int i = 0; i < num_layers; ++i) { + out = connector.get_block(i)->forward(&runner_ctx, out, pe, nullptr); + } + out = LTX::parameterless_rms_norm(compute_ctx, out, 1e-6f); + } + } + if (probe_stage >= 4 && include_caption_projection) { + out = caption_projection.forward(&runner_ctx, out); + } + + ggml_build_forward_expand(gf, out); + return gf; + } + + sd::Tensor compute(int n_threads, const sd::Tensor& normed_in, int stage = 4) { + probe_stage = stage; + auto get_graph = [&]() -> ggml_cgraph* { return build_graph(normed_in); }; + return take_or_empty(GGMLRunner::compute(get_graph, n_threads, true)); + } + }; + +} // namespace LTXConnector + +#endif // __LTX_CONNECTOR_HPP__ diff --git a/src/ltx_rope.hpp b/src/ltx_rope.hpp new file mode 100644 index 000000000..c058e7452 --- /dev/null +++ b/src/ltx_rope.hpp @@ -0,0 +1,350 @@ +#ifndef __LTX_ROPE_HPP__ +#define __LTX_ROPE_HPP__ + +#include +#include +#include "ggml_extend.hpp" + +namespace LTXRope { + // Generate a log-spaced frequency grid from 1 to theta, scaled by pi/2. + // Returns num_freqs = inner_dim / (2 * n_pos_dims) values. + // + // Python reference: generate_freq_grid_pytorch in ltx_core/model/transformer/rope.py. + // We mirror the fp32 linspace path byte-exactly: torch.linspace(0., 1., N, fp32) + // produces indices computed as `i * (1/(N-1))` in fp32 (start + step*i), so we + // replicate that order of operations rather than `(double)i / (N-1)` which + // differs by ~1 ULP at the tail. That 1-ULP freq drift becomes ~3-5 ULPs in + // the freq value and ~5e-2 cos/sin error once the angle hits 1e5 radians at + // T=8. `pow(theta, v)` is then computed in fp32 (std::powf) to match. 
+ __STATIC_INLINE__ std::vector generate_freq_grid(float theta, + int n_pos_dims, + int inner_dim) { + int n_elem = 2 * n_pos_dims; + int num_freqs = inner_dim / n_elem; + std::vector indices(num_freqs); + // Compute in fp64 then cast. For the video DiT (3D RoPE, max_pos normalizes + // to [0, 1]) fp32 would be fine, but the connector's 1D RoPE uses max_pos=[1] + // so raw integer positions feed into the angle → arguments reach ~2e5 radians + // at T=8. At that scale, fp32 libm `exp(t*log(theta))` drifts ~1 ULP in + // the freq value, cascading to ~5e-2 cos/sin diffs vs the numpy-fp64 reference + // used by the connector dumper (`double_precision_rope=True`). fp64 pow matches + // numpy closely enough to land connector parity at ~2e-3 max_abs. + constexpr double pi_half = 1.57079632679489661923; + double theta_d = static_cast(theta); + for (int i = 0; i < num_freqs; ++i) { + double t = num_freqs == 1 ? 0.0 : static_cast(i) / (num_freqs - 1); + indices[i] = static_cast(std::pow(theta_d, t) * pi_half); + } + return indices; + } + + // Build a 3D indices grid for a video latent of shape (F, H, W). + // + // Mirrors the real LTX-2 pipeline: VideoLatentTools.create_initial_state -> + // get_patch_grid_bounds -> get_pixel_coords (ltx_core/components/patchifiers.py and + // ltx_core/tools.py). Per-axis behaviour: + // latent_coords[axis] = [f, f+1] (integer latent indices per patch) + // pixel_coords[axis] = latent_coords * scale_factors[axis] + // if causal_fix: pixel_coords[0] = clamp(pixel_coords[0] + 1 - scale_factors[0], 0, +) + // positions[0] /= fps (temporal axis only) + // if use_middle_indices_grid: pos = midpoint(start, end); else pos = start + // + // Defaults ({1,1,1}, causal_fix=false, fps=1) preserve the parity-test flow, which + // feeds the Python model the simplified (f, h, w) positions directly. Real inference + // MUST pass scale_factors={8, 32, 32} and causal_fix=true (the LTX-2 VAE scale). 
+ // + // Returns a 3×(F*H*W) matrix with layout [axis][token_idx]. + __STATIC_INLINE__ std::vector> gen_video_positions(int F, + int H, + int W, + bool use_middle_indices_grid = true, + float fps = 1.0f, + const std::vector& scale_factors = {1, 1, 1}, + bool causal_fix = false) { + GGML_ASSERT(fps > 0.0f); + GGML_ASSERT(scale_factors.size() == 3); + int total = F * H * W; + std::vector> pos(3, std::vector(total, 0.f)); + const float s0 = static_cast(scale_factors[0]); + const float s1 = static_cast(scale_factors[1]); + const float s2 = static_cast(scale_factors[2]); + for (int f = 0; f < F; ++f) { + float t_s = static_cast(f) * s0; + float t_e = static_cast(f + 1) * s0; + if (causal_fix) { + const float shift = 1.f - s0; + t_s = std::max(0.f, t_s + shift); + t_e = std::max(0.f, t_e + shift); + } + t_s /= fps; + t_e /= fps; + for (int h = 0; h < H; ++h) { + float h_s = static_cast(h) * s1; + float h_e = static_cast(h + 1) * s1; + for (int w = 0; w < W; ++w) { + float w_s = static_cast(w) * s2; + float w_e = static_cast(w + 1) * s2; + int idx = (f * H + h) * W + w; + if (use_middle_indices_grid) { + pos[0][idx] = (t_s + t_e) * 0.5f; + pos[1][idx] = (h_s + h_e) * 0.5f; + pos[2][idx] = (w_s + w_e) * 0.5f; + } else { + pos[0][idx] = t_s; + pos[1][idx] = h_s; + pos[2][idx] = w_s; + } + } + } + } + return pos; + } + + // Precompute interleaved cos/sin freqs for LTX-2 RoPE. + // positions[axis][token]: fractional-ready float positions, size n_pos_dims * T. + // max_pos: normalisation per axis, e.g. {20, 2048, 2048}. + // Returns a packed [2, T, inner_dim] vector: slice [0] = cos, slice [1] = sin. 
+ __STATIC_INLINE__ std::vector precompute_freqs_cis_interleaved(const std::vector>& positions, + int inner_dim, + float theta = 10000.f, + const std::vector& max_pos = {20, 2048, 2048}) { + int n_pos_dims = static_cast(positions.size()); + GGML_ASSERT(n_pos_dims > 0); + GGML_ASSERT(static_cast(max_pos.size()) == n_pos_dims); + int T = static_cast(positions[0].size()); + + int n_elem = 2 * n_pos_dims; + int num_freqs = inner_dim / n_elem; + int pad_size = inner_dim - (num_freqs * n_pos_dims * 2); + + std::vector freq_grid = generate_freq_grid(theta, n_pos_dims, inner_dim); // [num_freqs] + + std::vector pe(2 * T * inner_dim, 0.f); + // Slice 0 (cos) starts at offset 0, slice 1 (sin) starts at T * inner_dim. + size_t cos_off = 0; + size_t sin_off = static_cast(T) * inner_dim; + + // Initialise the pad region: cos = 1.0, sin = 0.0. + for (int t = 0; t < T; ++t) { + for (int i = 0; i < pad_size; ++i) { + pe[cos_off + static_cast(t) * inner_dim + i] = 1.f; + } + } + + for (int t = 0; t < T; ++t) { + std::vector frac_pos(n_pos_dims); + for (int d = 0; d < n_pos_dims; ++d) { + frac_pos[d] = positions[d][t] / static_cast(max_pos[d]); + } + // Freq layout after flatten is [f * n_pos_dims + d], so pair index p = f*n_pos_dims + d. + // After repeat_interleave(2), each pair p corresponds to slots (2p, 2p+1) in the [pad_size:] region. + // + // Note: compute cos/sin in double precision then cast to float. At high frequencies + // (theta^1 * pi/2 ≈ 15708) times (2*t - 1), the angle reaches hundreds of thousands of + // radians — fp32 argument reduction in std::cosf/sinf loses enough precision to drift + // ~5e-2 from PyTorch's tensor-level cos/sin. Python's torch.cos does the reduction + // against a more precise modulus internally (matching fp64 behavior closely enough). 
+ for (int f = 0; f < num_freqs; ++f) { + for (int d = 0; d < n_pos_dims; ++d) { + double angle = static_cast(freq_grid[f]) * + (static_cast(frac_pos[d]) * 2.0 - 1.0); + float c = static_cast(std::cos(angle)); + float s = static_cast(std::sin(angle)); + int pair_i = f * n_pos_dims + d; + int slot0 = pad_size + 2 * pair_i; + int slot1 = pad_size + 2 * pair_i + 1; + pe[cos_off + static_cast(t) * inner_dim + slot0] = c; + pe[cos_off + static_cast(t) * inner_dim + slot1] = c; + pe[sin_off + static_cast(t) * inner_dim + slot0] = s; + pe[sin_off + static_cast(t) * inner_dim + slot1] = s; + } + } + } + return pe; + } + + // Apply LTX-2 interleaved rotary embedding to x. + // x: [inner_dim, T, B] (ggml ne order; logical shape [B, T, inner_dim]) + // cos, sin: [inner_dim, T, 1] (broadcast across batch) + // Returns x rotated, same shape as x. + __STATIC_INLINE__ ggml_tensor* apply_rotary_emb_interleaved(ggml_context* ctx, + ggml_tensor* x, + ggml_tensor* cos_freq, + ggml_tensor* sin_freq) { + int64_t inner_dim = x->ne[0]; + int64_t T = x->ne[1]; + int64_t B = x->ne[2]; + GGML_ASSERT(inner_dim % 2 == 0); + + // Reshape to pairs: [2, inner_dim/2, T, B]. + auto x_pairs = ggml_reshape_4d(ctx, x, 2, inner_dim / 2, T, B); + + // Views: x_even (offset 0) and x_odd (offset nb[0]) each shape [1, inner_dim/2, T, B]. + auto x_even = ggml_view_4d(ctx, x_pairs, 1, inner_dim / 2, T, B, + x_pairs->nb[1], x_pairs->nb[2], x_pairs->nb[3], 0); + auto x_odd = ggml_view_4d(ctx, x_pairs, 1, inner_dim / 2, T, B, + x_pairs->nb[1], x_pairs->nb[2], x_pairs->nb[3], x_pairs->nb[0]); + x_even = ggml_cont(ctx, x_even); + x_odd = ggml_cont(ctx, x_odd); + + // Rotated pair (−x_odd, x_even) → concat along dim 0 → [2, inner_dim/2, T, B]. 
+ auto neg_x_odd = ggml_scale(ctx, x_odd, -1.f); + auto rotated = ggml_concat(ctx, neg_x_odd, x_even, 0); + rotated = ggml_reshape_3d(ctx, rotated, inner_dim, T, B); + + // out = x * cos + rotated * sin + auto out = ggml_add(ctx, ggml_mul(ctx, x, cos_freq), ggml_mul(ctx, rotated, sin_freq)); + return out; + } + + // Precompute SPLIT cos/sin freqs for LTX-2.3 DiT. Python reference: + // `precompute_freqs_cis(..., rope_type=LTXRopeType.SPLIT)`. + // - Unlike the interleaved variant, freqs are NOT repeat_interleaved; each of + // the inner_dim/2 frequencies is broadcast once across the corresponding + // position in the first AND second halves of head_dim. + // - cos/sin are reshaped to per-head: shape [B, T, H, head_dim/2]. + // - We pack both into a single buffer of ne [head_dim/2, num_heads, T, 2] + // (slice 0 = cos, slice 1 = sin), matching the interleaved helper's + // single-buffer convention. split_pe_split() below slices that back. + // + // freqs flattened length is num_freqs * n_pos_dims; when it's less than + // inner_dim/2, the leading (pad_size) slots are filled cos=1, sin=0, matching + // Python's `split_freqs_cis`. 
+ __STATIC_INLINE__ std::vector precompute_freqs_cis_split(const std::vector>& positions, + int inner_dim, + int num_heads, + float theta = 10000.f, + const std::vector& max_pos = {20, 2048, 2048}) { + int n_pos_dims = static_cast(positions.size()); + GGML_ASSERT(n_pos_dims > 0); + GGML_ASSERT(static_cast(max_pos.size()) == n_pos_dims); + GGML_ASSERT(inner_dim % (2 * num_heads) == 0); + int T = static_cast(positions[0].size()); + int half_dim = inner_dim / 2; // per-token freq count + int head_dim2 = half_dim / num_heads; // per-head freq count + + int n_elem = 2 * n_pos_dims; + int num_freqs = inner_dim / n_elem; + int current = num_freqs * n_pos_dims; // pre-pad flat freq count + int pad_size = half_dim - current; + GGML_ASSERT(pad_size >= 0); + + std::vector freq_grid = generate_freq_grid(theta, n_pos_dims, inner_dim); + + // Output layout (ne): [head_dim/2, num_heads, T, 2]. Flat index: + // (slice=cos/sin)*T*num_heads*head_dim2 + t*num_heads*head_dim2 + h*head_dim2 + k + std::vector pe(2 * T * num_heads * head_dim2, 0.f); + size_t cos_off = 0; + size_t sin_off = static_cast(T) * num_heads * head_dim2; + + // Pad region (first `pad_size` columns of the per-token freq vector): cos=1, sin=0. + // Per-head reshape means pad_size slots at the start of the head-major flat + // vector. Since cos/sin for a token are stored as [h=0 head_dim2, h=1 head_dim2, …], + // the pad falls in the first pad_size consecutive positions across the head groups. 
+ for (int t = 0; t < T; ++t) { + for (int p = 0; p < pad_size; ++p) { + int h = p / head_dim2; + int k = p % head_dim2; + size_t dst = static_cast(t) * num_heads * head_dim2 + h * head_dim2 + k; + pe[cos_off + dst] = 1.f; + pe[sin_off + dst] = 0.f; + } + } + + constexpr double pi_half = 1.57079632679489661923; + (void)pi_half; + for (int t = 0; t < T; ++t) { + std::vector frac_pos(n_pos_dims); + for (int d = 0; d < n_pos_dims; ++d) { + frac_pos[d] = positions[d][t] / static_cast(max_pos[d]); + } + // Non-pad slots start at column `pad_size` in the flat per-token freq vector. + // Python layout: freqs = (indices * (fractional*2-1)).transpose(-1,-2).flatten(2). + // With indices shape [num_freqs] and fractional [n_pos_dims], after broadcast + // and transpose the order is [f * n_pos_dims + d]. Slot index in the padded + // per-token vector = pad_size + f*n_pos_dims + d. + for (int f = 0; f < num_freqs; ++f) { + for (int d = 0; d < n_pos_dims; ++d) { + double angle = static_cast(freq_grid[f]) * + (static_cast(frac_pos[d]) * 2.0 - 1.0); + float c = static_cast(std::cos(angle)); + float s = static_cast(std::sin(angle)); + int flat_slot = pad_size + f * n_pos_dims + d; + int h = flat_slot / head_dim2; + int k = flat_slot % head_dim2; + size_t dst = static_cast(t) * num_heads * head_dim2 + h * head_dim2 + k; + pe[cos_off + dst] = c; + pe[sin_off + dst] = s; + } + } + } + return pe; + } + + // Split-half rotary embedding. Python: apply_split_rotary_emb. + // first_half = x[..., 0:head_dim/2] + // second_half = x[..., head_dim/2:head_dim] + // out = concat(first*cos - second*sin, second*cos + first*sin, dim=last) + // Operates per-head. x ne=[inner_dim, T, B]; pe tensors (cos/sin) ne=[head_dim/2, num_heads, T, 1]. 
+ __STATIC_INLINE__ ggml_tensor* apply_rotary_emb_split(ggml_context* ctx, + ggml_tensor* x, + ggml_tensor* cos_freq, + ggml_tensor* sin_freq, + int num_heads) { + int64_t inner_dim = x->ne[0]; + int64_t T = x->ne[1]; + int64_t B = x->ne[2]; + GGML_ASSERT(inner_dim % (2 * num_heads) == 0); + int64_t head_dim = inner_dim / num_heads; + int64_t half = head_dim / 2; + + // Reshape x [inner_dim, T, B] → [head_dim, num_heads, T, B], then split halves. + auto x4 = ggml_reshape_4d(ctx, x, head_dim, num_heads, T, B); + + // first_half view: offset 0, shape [half, num_heads, T, B]. + auto first = ggml_view_4d(ctx, x4, half, num_heads, T, B, + x4->nb[1], x4->nb[2], x4->nb[3], 0); + // second_half view: offset = half * sizeof(el). + auto second = ggml_view_4d(ctx, x4, half, num_heads, T, B, + x4->nb[1], x4->nb[2], x4->nb[3], half * x4->nb[0]); + first = ggml_cont(ctx, first); + second = ggml_cont(ctx, second); + + // cos/sin ne [half, num_heads, T, 1] broadcast on B axis with first/second [half, num_heads, T, B]. + auto first_out = ggml_sub(ctx, ggml_mul(ctx, first, cos_freq), + ggml_mul(ctx, second, sin_freq)); + auto second_out = ggml_add(ctx, ggml_mul(ctx, second, cos_freq), + ggml_mul(ctx, first, sin_freq)); + + // Re-concat along dim 0 (head_dim) → [head_dim, num_heads, T, B]. + auto joined = ggml_concat(ctx, first_out, second_out, 0); + joined = ggml_reshape_3d(ctx, joined, inner_dim, T, B); + return joined; + } + + // Slice a packed split pe buffer of ne [half, num_heads, T, 2] into cos (slice 0) + // and sin (slice 1) views, each ne=[half, num_heads, T, 1]. 
+ __STATIC_INLINE__ std::pair split_pe_split(ggml_context* ctx, ggml_tensor* pe) { + int64_t half = pe->ne[0]; + int64_t num_heads = pe->ne[1]; + int64_t T = pe->ne[2]; + auto cos_freq = ggml_view_4d(ctx, pe, half, num_heads, T, 1, + pe->nb[1], pe->nb[2], pe->nb[3], 0); + auto sin_freq = ggml_view_4d(ctx, pe, half, num_heads, T, 1, + pe->nb[1], pe->nb[2], pe->nb[3], pe->nb[3]); + return {cos_freq, sin_freq}; + } + + // Convenience: split a packed [2, T, inner_dim] pe tensor (slice 0 = cos, slice 1 = sin) + // into two views usable as cos/sin operands. + __STATIC_INLINE__ std::pair split_pe(ggml_context* ctx, ggml_tensor* pe) { + // pe: [inner_dim, T, 2] in ggml ne order. + int64_t inner_dim = pe->ne[0]; + int64_t T = pe->ne[1]; + auto cos_freq = ggml_view_3d(ctx, pe, inner_dim, T, 1, pe->nb[1], pe->nb[2], 0); + auto sin_freq = ggml_view_3d(ctx, pe, inner_dim, T, 1, pe->nb[1], pe->nb[2], pe->nb[2]); + return {cos_freq, sin_freq}; + } +}; // namespace LTXRope + +#endif // __LTX_ROPE_HPP__ diff --git a/src/ltxv.hpp b/src/ltxv.hpp index fb37dbe02..0e493443d 100644 --- a/src/ltxv.hpp +++ b/src/ltxv.hpp @@ -5,9 +5,13 @@ namespace LTXV { + enum class SpatialPadding { ZEROS, REFLECT }; + class CausalConv3d : public GGMLBlock { protected: int time_kernel_size; + int spatial_kernel_size; + SpatialPadding spatial_padding; public: CausalConv3d(int64_t in_channels, @@ -15,52 +19,98 @@ namespace LTXV { int kernel_size = 3, std::tuple stride = {1, 1, 1}, int dilation = 1, - bool bias = true) { - time_kernel_size = kernel_size / 2; - blocks["conv"] = std::shared_ptr(new Conv3d(in_channels, + bool bias = true, + SpatialPadding padding_mode = SpatialPadding::ZEROS) { + // Python reference: self.time_kernel_size = kernel_size[0] — the full temporal kernel. + // Earlier revisions of this file used `kernel_size / 2` which under-padded by a factor of 2 for k>=3 + // and padded 1 frame when k=1/2 where no padding was expected. Match Python verbatim. 
+ time_kernel_size = kernel_size; + spatial_kernel_size = kernel_size; + spatial_padding = padding_mode; + // When using reflect padding we do it manually in forward(), so the inner Conv3d + // must run with spatial padding=0. For zeros mode the Conv3d handles padding itself. + int conv_pad_hw = (padding_mode == SpatialPadding::ZEROS) ? (kernel_size / 2) : 0; + blocks["conv"] = std::shared_ptr(new Conv3d(in_channels, out_channels, {kernel_size, kernel_size, kernel_size}, stride, - {0, kernel_size / 2, kernel_size / 2}, + {0, conv_pad_hw, conv_pad_hw}, {dilation, 1, 1}, bias)); } + // Helper: replicate the given single-frame tensor `count` times along the depth axis. + // Returns a [IW, IH, count, N*IC] tensor. count must be >= 1. + static ggml_tensor* repeat_frame(ggml_context* ctx, ggml_tensor* frame, int count) { + auto out = frame; + for (int i = 1; i < count; i++) { + out = ggml_concat(ctx, out, frame, 2); + } + return out; + } + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, bool causal = true) { - // x: [N*IC, ID, IH, IW] - // result: [N*OC, OD, OH, OW] - auto conv = std::dynamic_pointer_cast(blocks["conv"]); + // x logical shape: [N*IC, ID, IH, IW] (Python order); ggml ne: [IW, IH, ID, N*IC] + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + auto ggml_cx = ctx->ggml_ctx; + + int pad_front = 0; + int pad_back = 0; if (causal) { - auto h = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2)); // [ID, N*IC, IH, IW] - auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0); // [N*IC, IH, IW] - first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]); // [N*IC, 1, IH, IW] - auto first_frame_pad = first_frame; - for (int i = 1; i < time_kernel_size - 1; i++) { - first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2); - } - x = ggml_concat(ctx, first_frame_pad, x, 2); + pad_front = time_kernel_size - 1; } else { - auto h = ggml_cont(ctx, 
ggml_permute(ctx, x, 0, 1, 3, 2)); // [ID, N*IC, IH, IW] - int64_t offset = h->nb[2] * h->ne[2]; + pad_front = (time_kernel_size - 1) / 2; + pad_back = (time_kernel_size - 1) / 2; + } - auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0); // [N*IC, IH, IW] - first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]); // [N*IC, 1, IH, IW] - auto first_frame_pad = first_frame; - for (int i = 1; i < (time_kernel_size - 1) / 2; i++) { - first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2); - } + if (pad_front > 0 || pad_back > 0) { + // Extract first frame as a [IW, IH, 1, N*IC] view on x along the depth axis (ne[2]). + auto first_frame = ggml_view_4d(ggml_cx, x, + x->ne[0], x->ne[1], 1, x->ne[3], + x->nb[1], x->nb[2], x->nb[3], 0); + first_frame = ggml_cont(ggml_cx, first_frame); - auto last_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], offset * (h->ne[3] - 1)); // [N*IC, IH, IW] - last_frame = ggml_reshape_4d(ctx, last_frame, last_frame->ne[0], last_frame->ne[1], 1, last_frame->ne[2]); // [N*IC, 1, IH, IW] - auto last_frame_pad = last_frame; - for (int i = 1; i < (time_kernel_size - 1) / 2; i++) { - last_frame_pad = ggml_concat(ctx, last_frame_pad, last_frame, 2); + if (pad_front > 0) { + auto front_pad = repeat_frame(ggml_cx, first_frame, pad_front); + x = ggml_concat(ggml_cx, front_pad, x, 2); + } + if (pad_back > 0) { + auto last_frame = ggml_view_4d(ggml_cx, x, + x->ne[0], x->ne[1], 1, x->ne[3], + x->nb[1], x->nb[2], x->nb[3], (x->ne[2] - 1) * x->nb[2]); + last_frame = ggml_cont(ggml_cx, last_frame); + auto back_pad = repeat_frame(ggml_cx, last_frame, pad_back); + x = ggml_concat(ggml_cx, x, back_pad, 2); } + } - x = ggml_concat(ctx, first_frame_pad, x, 2); - x = ggml_concat(ctx, x, last_frame_pad, 2); + // Spatial reflect padding (H, W by k/2 each side). 
nn.Conv3d with padding_mode='reflect' + // mirrors the edge rows/cols: [a,b,c,d] with pad=1 → [b,a,b,c,d,c]. + if (spatial_padding == SpatialPadding::REFLECT) { + int pad = spatial_kernel_size / 2; + if (pad > 0) { + GGML_ASSERT(pad == 1 && "reflect padding only implemented for kernel=3 (pad=1)"); + x = ggml_cont(ggml_cx, x); + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + // H-axis reflect: top = row 1, bottom = row H-2. + auto row_top = ggml_cont(ggml_cx, ggml_view_4d(ggml_cx, x, W, 1, T, C, + x->nb[1], x->nb[2], x->nb[3], 1 * x->nb[1])); + auto row_bot = ggml_cont(ggml_cx, ggml_view_4d(ggml_cx, x, W, 1, T, C, + x->nb[1], x->nb[2], x->nb[3], (H - 2) * x->nb[1])); + x = ggml_concat(ggml_cx, row_top, x, 1); + x = ggml_concat(ggml_cx, x, row_bot, 1); + x = ggml_cont(ggml_cx, x); + W = x->ne[0]; H = x->ne[1]; T = x->ne[2]; C = x->ne[3]; + // W-axis reflect: left = col 1, right = col W-2. + auto col_left = ggml_cont(ggml_cx, ggml_view_4d(ggml_cx, x, 1, H, T, C, + x->nb[1], x->nb[2], x->nb[3], 1 * x->nb[0])); + auto col_right = ggml_cont(ggml_cx, ggml_view_4d(ggml_cx, x, 1, H, T, C, + x->nb[1], x->nb[2], x->nb[3], (W - 2) * x->nb[0])); + x = ggml_concat(ggml_cx, col_left, x, 0); + x = ggml_concat(ggml_cx, x, col_right, 0); + } } x = conv->forward(ctx, x); diff --git a/src/ltxvae.hpp b/src/ltxvae.hpp new file mode 100644 index 000000000..01e01116d --- /dev/null +++ b/src/ltxvae.hpp @@ -0,0 +1,933 @@ +#ifndef __LTXVAE_HPP__ +#define __LTXVAE_HPP__ + +#include "common_block.hpp" +#include "ltxv.hpp" // CausalConv3d +#include "ltxvae_primitives.hpp" // space/depth, pixel_norm, pcs_* +#include "vae.hpp" // VAE base class + +// LTX-2 video VAE. Companion to src/ltxvae_primitives.hpp (pure ggml ops) — +// this file adds the parameterized composition blocks (ResnetBlock3D, +// UNetMidBlock3D, SpaceToDepthDownsample, DepthToSpaceUpsample) and the +// VideoEncoder / VideoDecoder top-levels. 
+// +// Tensor convention throughout: B=1 collapsed; ggml ne=[W, H, T, C]. +// Weight naming mirrors the Python reference verbatim — see +// `/tmp/vae_ref/tensor_names.txt` for the canonical prefix layout. + +namespace LTXVAE { + + // ---------- TimestepEmbedder ---------- + // + // PixArtAlphaCombinedTimestepSizeEmbeddings with size_emb_dim=0. Python + // structure: `.timestep_embedder.linear_{1,2}`. Sinusoidal projection into + // TIME_PROJ_DIM (256) fed to a two-Linear MLP with SiLU between. + + struct TimestepEmbedder : public GGMLBlock { + protected: + int embedding_dim = 0; + static constexpr int TIME_PROJ_DIM = 256; + + public: + TimestepEmbedder() = default; + TimestepEmbedder(int embedding_dim) : embedding_dim(embedding_dim) { + blocks["timestep_embedder.linear_1"] = std::make_shared(TIME_PROJ_DIM, embedding_dim, true); + blocks["timestep_embedder.linear_2"] = std::make_shared(embedding_dim, embedding_dim, true); + } + + // timestep: ne=[B]. Returns ne=[embedding_dim, B]. + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* timestep) { + auto l1 = std::dynamic_pointer_cast(blocks["timestep_embedder.linear_1"]); + auto l2 = std::dynamic_pointer_cast(blocks["timestep_embedder.linear_2"]); + auto proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, TIME_PROJ_DIM, 10000, 1.0f); + auto h = l1->forward(ctx, proj); + h = ggml_silu_inplace(ctx->ggml_ctx, h); + return l2->forward(ctx, h); + } + }; + + // ---------- ResnetBlock3D ---------- + // + // Python forward (when timestep_conditioning=True): + // h = norm1(x) [PixelNorm] + // ada = scale_shift_table + time_embed.reshape(B, 4, in_channels, 1,1,1) + // shift1, scale1, shift2, scale2 = ada.unbind(dim=1) + // h = h * (1 + scale1) + shift1 + // h = silu(h); h = conv1(h) + // h = norm2(h); h = h * (1 + scale2) + shift2 + // h = silu(h); h = conv2(h) + // return input + h + // + // When in_channels != out_channels, the skip path goes through + // norm3 = GroupNorm(num_groups=1, ...) 
+ conv_shortcut (1×1×1 Conv3d). + // Our parity config keeps in==out, so we hard-disable that path until + // we land a use case that needs it. + // + // inject_noise is not yet supported (would require a seeded randn in ggml). + + struct ResnetBlock3D : public GGMLBlock { + protected: + int in_channels = 0; + int out_channels = 0; + bool timestep_conditioning = false; + bool has_shortcut = false; + float eps = 1e-6f; + + void init_params(ggml_context* ctx, const String2TensorStorage&, const std::string /*prefix*/ = "") override { + if (timestep_conditioning) { + // Python ne: [4, in_channels] → GGML ne [in_channels, 4]. + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, in_channels, 4); + } + } + + public: + ResnetBlock3D() = default; + ResnetBlock3D(int in_ch, int out_ch, bool timestep_cond, float eps_ = 1e-8f, + LTXV::SpatialPadding pad = LTXV::SpatialPadding::ZEROS) + : in_channels(in_ch), + out_channels(out_ch), + timestep_conditioning(timestep_cond), + has_shortcut(in_ch != out_ch), + eps(eps_) { + blocks["conv1"] = std::make_shared( + in_ch, out_ch, 3, std::tuple{1,1,1}, 1, true, pad); + blocks["conv2"] = std::make_shared( + out_ch, out_ch, 3, std::tuple{1,1,1}, 1, true, pad); + if (has_shortcut) { + GGML_ABORT("ResnetBlock3D with in != out not yet implemented (norm3 + conv_shortcut)"); + } + } + + // x: ne=[W, H, T, C_in]. time_embed (optional): ne=[4*in_channels, B=1]. + // `causal` propagates to the inner CausalConv3d.forward calls. + // If traces is non-null, pushes intermediates in order: + // 0 post_norm1, 1 shift1, 2 scale1, 3 post_adaln1, 4 post_conv1, + // 5 post_norm2, 6 shift2, 7 scale2, 8 post_adaln2, 9 post_conv2, 10 final. 
+ ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* time_embed = nullptr, + std::vector* traces = nullptr, bool causal = true) { + auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); + auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); + + auto input = x; + auto h = pixel_norm(ctx->ggml_ctx, x, eps); + if (traces) traces->push_back(ggml_cont(ctx->ggml_ctx, h)); + + ggml_tensor *shift1 = nullptr, *scale1 = nullptr; + ggml_tensor *shift2 = nullptr, *scale2 = nullptr; + if (timestep_conditioning) { + GGML_ASSERT(time_embed != nullptr); + auto sst = params["scale_shift_table"]; // ne [in_channels, 4] + // time_embed has ne [4*in_channels, B=1]. Reshape to ne [in_channels, 4] (implicit B=1). + auto te = ggml_reshape_2d(ctx->ggml_ctx, time_embed, in_channels, 4); + auto ada = ggml_add(ctx->ggml_ctx, te, sst); // [in_channels, 4] + + shift1 = ggml_ext_slice(ctx->ggml_ctx, ada, 1, 0, 1); + scale1 = ggml_ext_slice(ctx->ggml_ctx, ada, 1, 1, 2); + shift2 = ggml_ext_slice(ctx->ggml_ctx, ada, 1, 2, 3); + scale2 = ggml_ext_slice(ctx->ggml_ctx, ada, 1, 3, 4); + if (traces) { + traces->push_back(ggml_cont(ctx->ggml_ctx, shift1)); + traces->push_back(ggml_cont(ctx->ggml_ctx, scale1)); + } + // Reshape happens below; the apply also happens below. + // Reshape each [in_channels, 1] → [1, 1, 1, in_channels] so they broadcast + // over (W, H, T) when added/multiplied with h [W, H, T, in_channels]. 
+ shift1 = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, shift1), 1, 1, 1, in_channels); + scale1 = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, scale1), 1, 1, 1, in_channels); + shift2 = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, shift2), 1, 1, 1, in_channels); + scale2 = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, scale2), 1, 1, 1, in_channels); + + auto h_scaled = ggml_mul(ctx->ggml_ctx, h, scale1); + h = ggml_add(ctx->ggml_ctx, h, h_scaled); + h = ggml_add(ctx->ggml_ctx, h, shift1); + if (traces) traces->push_back(ggml_cont(ctx->ggml_ctx, h)); + } + + h = ggml_silu(ctx->ggml_ctx, h); + h = conv1->forward(ctx, h, causal); + if (traces) traces->push_back(ggml_cont(ctx->ggml_ctx, h)); + + h = pixel_norm(ctx->ggml_ctx, h, eps); + if (traces) traces->push_back(ggml_cont(ctx->ggml_ctx, h)); + + if (timestep_conditioning) { + auto h_scaled = ggml_mul(ctx->ggml_ctx, h, scale2); + h = ggml_add(ctx->ggml_ctx, h, h_scaled); + h = ggml_add(ctx->ggml_ctx, h, shift2); + } + + h = ggml_silu(ctx->ggml_ctx, h); + h = conv2->forward(ctx, h, causal); + + // in_channels == out_channels so skip is Identity (the `has_shortcut` path aborts above). + return ggml_add(ctx->ggml_ctx, h, input); + } + }; + + // ---------- UNetMidBlock3D ---------- + + struct UNetMidBlock3D : public GGMLBlock { + protected: + int in_channels = 0; + int num_layers = 0; + bool timestep_conditioning = false; + + public: + UNetMidBlock3D() = default; + UNetMidBlock3D(int in_ch, int num_layers, bool timestep_cond, + LTXV::SpatialPadding pad = LTXV::SpatialPadding::ZEROS) + : in_channels(in_ch), num_layers(num_layers), timestep_conditioning(timestep_cond) { + for (int i = 0; i < num_layers; i++) { + blocks["res_blocks." + std::to_string(i)] = std::make_shared(in_ch, in_ch, timestep_cond, 1e-8f, pad); + } + if (timestep_cond) { + blocks["time_embedder"] = std::make_shared(in_ch * 4); + } + } + + // timestep: ne=[B=1] if conditioning enabled, else null. 
+ ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* timestep = nullptr, + std::vector* traces = nullptr, bool causal = true) { + ggml_tensor* time_embed = nullptr; + if (timestep_conditioning) { + GGML_ASSERT(timestep != nullptr); + auto te = std::dynamic_pointer_cast(blocks["time_embedder"]); + time_embed = te->forward(ctx, timestep); // ne=[4*in_channels, 1] + if (traces) traces->push_back(ggml_cont(ctx->ggml_ctx, time_embed)); + } + for (int i = 0; i < num_layers; i++) { + auto res = std::dynamic_pointer_cast(blocks["res_blocks." + std::to_string(i)]); + x = res->forward(ctx, x, time_embed, traces, causal); + } + return x; + } + }; + + // ---------- SpaceToDepthDownsample (encoder) ---------- + // + // Python forward: + // if stride[0]==2: x = cat([x[:,:,:1], x], dim=2) # duplicate first frame + // x_in = rearrange(x, "b c (d p1)(h p2)(w p3) -> b (c p1 p2 p3) d h w", ...) + // x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=group_size).mean(dim=2) + // x = self.conv(x, causal); x = rearrange(x, ...s2d...); return x + x_in + + struct SpaceToDepthDownsample : public GGMLBlock { + protected: + int in_channels = 0; + int out_channels = 0; + int p1 = 1, p2 = 1, p3 = 1; + int group_size = 1; + + // Helper: collapse group-size consecutive channels via mean along axis 2 + // (after reshaping [W,H,T,C_exp] → [W*H, T, g, C_new]). + ggml_tensor* group_mean_channel(ggml_context* ctx, ggml_tensor* x, int g) const { + if (g == 1) return x; + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], Cexp = x->ne[3]; + GGML_ASSERT(Cexp % g == 0); + int64_t C_new = Cexp / g; + // Reshape: merge W,H; split Cexp into [g, C_new] (g innermost-of-that-group, + // matching einops "(c g)" with g innermost). + auto y = ggml_reshape_4d(ctx, x, W * H, T, g, C_new); + // Move g to innermost (axis 0) for ggml_mean. 
+ y = ggml_cont(ctx, ggml_permute(ctx, y, 1, 2, 0, 3)); // ne=[g, W*H, T, C_new] + y = ggml_mean(ctx, y); // ne=[1, W*H, T, C_new] + // Permute back & reshape to [W, H, T, C_new]. + y = ggml_cont(ctx, ggml_permute(ctx, y, 3, 0, 1, 2)); // ne=[W*H, T, C_new, 1] + y = ggml_reshape_4d(ctx, y, W, H, T, C_new); + return y; + } + + public: + SpaceToDepthDownsample() = default; + SpaceToDepthDownsample(int in_ch, int out_ch, std::tuple stride) + : in_channels(in_ch), out_channels(out_ch), + p1(std::get<0>(stride)), p2(std::get<1>(stride)), p3(std::get<2>(stride)) { + int prod = p1 * p2 * p3; + GGML_ASSERT((in_ch * prod) % out_ch == 0); + group_size = in_ch * prod / out_ch; + GGML_ASSERT(out_ch % prod == 0); + blocks["conv"] = std::make_shared(in_ch, out_ch / prod, 3); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, bool causal = true) { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + + // Duplicate first frame if temporal stride is 2. + if (p1 == 2) { + auto first = ggml_view_4d(ctx->ggml_ctx, x, + x->ne[0], x->ne[1], 1, x->ne[3], + x->nb[1], x->nb[2], x->nb[3], 0); + first = ggml_cont(ctx->ggml_ctx, first); + x = ggml_concat(ctx->ggml_ctx, first, x, 2); // prepend along T + } + + // Skip: s2d → group-mean. + auto x_in = space_to_depth(ctx->ggml_ctx, x, p1, p2, p3); + x_in = group_mean_channel(ctx->ggml_ctx, x_in, group_size); + + // Main: conv (preserves T because of causal padding, stride=1), then s2d. + auto y = conv->forward(ctx, x, causal); + y = space_to_depth(ctx->ggml_ctx, y, p1, p2, p3); + + return ggml_add(ctx->ggml_ctx, y, x_in); + } + }; + + // ---------- DepthToSpaceUpsample (decoder) ---------- + // + // For the parity test we only need residual=False (compress_time, compress_space). + // `compress_all` blocks with residual=True have a repeat-based skip path; we'll + // add that when a decoder config needs it. 
+ + struct DepthToSpaceUpsample : public GGMLBlock { + protected: + int in_channels = 0; + int p1 = 1, p2 = 1, p3 = 1; + int reduction_factor = 1; + + public: + DepthToSpaceUpsample() = default; + DepthToSpaceUpsample(int in_ch, std::tuple stride, int reduction_factor = 1, + LTXV::SpatialPadding pad = LTXV::SpatialPadding::ZEROS) + : in_channels(in_ch), + p1(std::get<0>(stride)), p2(std::get<1>(stride)), p3(std::get<2>(stride)), + reduction_factor(reduction_factor) { + int prod = p1 * p2 * p3; + int conv_out = prod * in_ch / reduction_factor; + blocks["conv"] = std::make_shared( + in_ch, conv_out, 3, std::tuple{1,1,1}, 1, true, pad); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, bool causal = true) { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + x = conv->forward(ctx, x, causal); + x = depth_to_space(ctx->ggml_ctx, x, p1, p2, p3); + if (p1 == 2) { + // Drop first frame along T to match Python x[:, :, 1:, ...]. + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + auto sliced = ggml_view_4d(ctx->ggml_ctx, x, + W, H, T - 1, C, + x->nb[1], x->nb[2], x->nb[3], x->nb[2]); // skip frame 0 + x = ggml_cont(ctx->ggml_ctx, sliced); + } + return x; + } + }; + + // ---------- PerChannelStatistics wrapper ---------- + // + // Python uses register_buffer("std-of-means", ...) and ("mean-of-means", ...) — + // dashed names which don't appear elsewhere in this codebase. We register them + // as tensors via init_params and carry the dashed names verbatim so loader + // name matching finds them. 
+ + struct PerChannelStatisticsBlock : public GGMLBlock { + protected: + int latent_channels = 0; + + void init_params(ggml_context* ctx, const String2TensorStorage&, const std::string /*prefix*/ = "") override { + params["std-of-means"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, latent_channels); + params["mean-of-means"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, latent_channels); + } + + public: + PerChannelStatisticsBlock() = default; + explicit PerChannelStatisticsBlock(int latent_channels) : latent_channels(latent_channels) {} + + ggml_tensor* normalize(GGMLRunnerContext* ctx, ggml_tensor* x) { + return pcs_normalize(ctx->ggml_ctx, x, params["mean-of-means"], params["std-of-means"]); + } + ggml_tensor* un_normalize(GGMLRunnerContext* ctx, ggml_tensor* x) { + return pcs_unnormalize(ctx->ggml_ctx, x, params["mean-of-means"], params["std-of-means"]); + } + }; + + // ---------- VideoEncoder ---------- + // + // The encoder config is a list of (block_name, block_config) tuples — we keep + // that shape in C++ via an EncoderBlockSpec. Only `res_x`, `compress_space_res`, + // `compress_time_res`, `compress_all_res` are handled here; more variants can + // be added as their use-cases land. `norm_layer` is pixel_norm only (group_norm + // would require new primitives). `latent_log_var` is UNIFORM only. 
+ + enum class EncoderBlockKind { + RES_X, + COMPRESS_SPACE_RES, // stride=(1,2,2) + COMPRESS_TIME_RES, // stride=(2,1,1) + COMPRESS_ALL_RES, // stride=(2,2,2) + }; + + struct EncoderBlockSpec { + EncoderBlockKind kind; + int num_layers = 1; // used for RES_X + int multiplier = 2; // used for compress_*_res + }; + + struct VideoEncoder : public GGMLBlock { + protected: + int in_channels = 3; + int latent_channels = 128; + int patch_size = 4; + std::vector encoder_blocks; + float eps = 1e-6f; + + public: + VideoEncoder() = default; + VideoEncoder(int in_ch, int latent_ch, int patch, + const std::vector& enc_blocks) + : in_channels(in_ch), latent_channels(latent_ch), patch_size(patch), + encoder_blocks(enc_blocks) { + int feature_ch = latent_ch; + int cur_in = in_ch * patch * patch; // after patchify + + blocks["conv_in"] = std::make_shared(cur_in, feature_ch, 3); + + int cur_c = feature_ch; + for (size_t i = 0; i < encoder_blocks.size(); ++i) { + const auto& b = encoder_blocks[i]; + std::string key = "down_blocks." + std::to_string(i); + switch (b.kind) { + case EncoderBlockKind::RES_X: + blocks[key] = std::make_shared(cur_c, b.num_layers, /*timestep_cond=*/false); + break; + case EncoderBlockKind::COMPRESS_SPACE_RES: + blocks[key] = std::make_shared(cur_c, cur_c * b.multiplier, std::tuple{1,2,2}); + cur_c *= b.multiplier; + break; + case EncoderBlockKind::COMPRESS_TIME_RES: + blocks[key] = std::make_shared(cur_c, cur_c * b.multiplier, std::tuple{2,1,1}); + cur_c *= b.multiplier; + break; + case EncoderBlockKind::COMPRESS_ALL_RES: + blocks[key] = std::make_shared(cur_c, cur_c * b.multiplier, std::tuple{2,2,2}); + cur_c *= b.multiplier; + break; + } + } + + // UNIFORM log-var: conv_out gets one extra channel for the shared logvar. + int conv_out_ch = latent_ch + 1; + blocks["conv_out"] = std::make_shared(cur_c, conv_out_ch, 3); + blocks["per_channel_statistics"] = std::make_shared(latent_ch); + } + + // sample: ne=[W, H, T, C=3] (B=1). 
Returns normalized latent ne=[W', H', T', latent_ch]. + // If trace_outputs is non-null, intermediates are pushed in this order: + // 0: post_patchify, 1: post_conv_in, 2..K-1: per down_block output, + // K: post_norm, K+1: post_conv_out, K+2: means_preNorm, K+3: latent. + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* sample, + std::vector* trace_outputs = nullptr) { + // patchify (distinct channel ordering from the SpaceToDepthDownsample blocks; + // see `patchify` comment in ltxvae_primitives.hpp). + auto x = patchify(ctx->ggml_ctx, sample, 1, patch_size, patch_size); + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x)); + + auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); + x = conv_in->forward(ctx, x); + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x)); + + for (size_t i = 0; i < encoder_blocks.size(); ++i) { + std::string key = "down_blocks." + std::to_string(i); + switch (encoder_blocks[i].kind) { + case EncoderBlockKind::RES_X: { + auto b = std::dynamic_pointer_cast(blocks[key]); + x = b->forward(ctx, x, nullptr); + break; + } + default: { + auto b = std::dynamic_pointer_cast(blocks[key]); + x = b->forward(ctx, x); + break; + } + } + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x)); + } + + x = pixel_norm(ctx->ggml_ctx, x, eps); + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x)); + x = ggml_silu(ctx->ggml_ctx, x); + + auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); + x = conv_out->forward(ctx, x); // ne=[W', H', T', latent_ch+1] + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x)); + + // UNIFORM log_var handling: means = x[:, :-1], we skip logvar entirely (it would be + // expanded then discarded after the chunk(2) split). Take the first latent_ch channels. 
+ auto means = ggml_view_4d(ctx->ggml_ctx, x, + x->ne[0], x->ne[1], x->ne[2], latent_channels, + x->nb[1], x->nb[2], x->nb[3], 0); + means = ggml_cont(ctx->ggml_ctx, means); + if (trace_outputs) trace_outputs->push_back(means); + + auto pcs = std::dynamic_pointer_cast(blocks["per_channel_statistics"]); + auto latent = pcs->normalize(ctx, means); + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, latent)); + return latent; + } + }; + + // ---------- VideoDecoder ---------- + + enum class DecoderBlockKind { + RES_X, + COMPRESS_SPACE, // stride=(1,2,2), residual=False + COMPRESS_TIME, // stride=(2,1,1), residual=False + COMPRESS_ALL, // stride=(2,2,2), residual configurable (default False here) + }; + + struct DecoderBlockSpec { + DecoderBlockKind kind; + int num_layers = 1; // RES_X + int multiplier = 1; // channel reduction factor for compress_* + // Per-res_block timestep conditioning. Defaults to true (older + // configs assumed all RES_X blocks were timestep-conditioned), but + // the real LTX-2 22B VAE only conditions the inner res_blocks; the + // outer ones lack scale_shift_table + time_embedder weights. + bool timestep_cond = true; + }; + + struct VideoDecoder : public GGMLBlock { + protected: + int latent_channels = 128; + int out_channels = 3; + int patch_size = 4; + int base_channels = 128; + bool timestep_conditioning = true; + std::vector decoder_blocks; // stored in ENCODER-side order; forward reverses + float eps = 1e-6f; + int feature_channels = 0; + // Decoder uses `reflect` spatial padding by default per the Python reference + // (VideoDecoderConfigurator.from_config default). All CausalConv3d instances we + // construct below are handed this padding mode. + static constexpr LTXV::SpatialPadding PAD = LTXV::SpatialPadding::REFLECT; + // Python configurator defaults: `causal_decoder=False`. All our CausalConv3d.forward + // calls within the decoder should therefore use causal=False. (Encoder uses True.) 
+ static constexpr bool DECODER_CAUSAL = false; + + void init_params(ggml_context* ctx, const String2TensorStorage&, const std::string /*prefix*/ = "") override { + if (timestep_conditioning) { + // Python: last_scale_shift_table = Parameter(torch.empty(2, feature_channels)). + params["last_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, feature_channels, 2); + // timestep_scale_multiplier: scalar. + params["timestep_scale_multiplier"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + } + } + + public: + VideoDecoder() = default; + VideoDecoder(int latent_ch, int out_ch, int patch, int base_ch, bool timestep_cond, + const std::vector& dec_blocks) + : latent_channels(latent_ch), out_channels(out_ch), patch_size(patch), + base_channels(base_ch), timestep_conditioning(timestep_cond), + decoder_blocks(dec_blocks) { + // Decoder's feature_channels = base_channels * 8 per LTX-2 default (3 upsample blocks × 2). + feature_channels = base_ch * 8; + + blocks["conv_in"] = std::make_shared( + latent_ch, feature_channels, 3, std::tuple{1,1,1}, 1, true, PAD); + + // Decoder config is stored in encoder-side order; construct up_blocks in REVERSED order + // (matching the Python `list(reversed(decoder_blocks))`). + int cur_c = feature_channels; + for (size_t i = 0; i < decoder_blocks.size(); ++i) { + const auto& b = decoder_blocks[decoder_blocks.size() - 1 - i]; + std::string key = "up_blocks." + std::to_string(i); + switch (b.kind) { + case DecoderBlockKind::RES_X: + // Per-block timestep conditioning is independent from the top-level + // last-step conditioning. The 22B VAE only ships scale_shift_table + + // time_embedder weights for the inner up_blocks (those reachable from + // dec_specs entries with timestep_cond=true). 
+ blocks[key] = std::make_shared(cur_c, b.num_layers, + timestep_conditioning && b.timestep_cond, + PAD); + break; + case DecoderBlockKind::COMPRESS_SPACE: + blocks[key] = std::make_shared(cur_c, std::tuple{1,2,2}, b.multiplier, PAD); + cur_c = cur_c / b.multiplier; + break; + case DecoderBlockKind::COMPRESS_TIME: + blocks[key] = std::make_shared(cur_c, std::tuple{2,1,1}, b.multiplier, PAD); + cur_c = cur_c / b.multiplier; + break; + case DecoderBlockKind::COMPRESS_ALL: + blocks[key] = std::make_shared(cur_c, std::tuple{2,2,2}, b.multiplier, PAD); + cur_c = cur_c / b.multiplier; + break; + } + } + + int final_out_ch = out_ch * patch * patch; + blocks["conv_out"] = std::make_shared( + cur_c, final_out_ch, 3, std::tuple{1,1,1}, 1, true, PAD); + + if (timestep_conditioning) { + blocks["last_time_embedder"] = std::make_shared(feature_channels * 2); + } + blocks["per_channel_statistics"] = std::make_shared(latent_ch); + } + + // Trace stage order (for parity debugging): + // 0 post_unnorm, 1 post_conv_in, 2..K-1 per up_block output, + // K post_pixel_norm (pre-ada), K+1 post_ada, K+2 post_conv_out, K+3 video_out. + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* latent, ggml_tensor* timestep = nullptr, + std::vector* trace_outputs = nullptr) { + auto pcs = std::dynamic_pointer_cast(blocks["per_channel_statistics"]); + auto x = pcs->un_normalize(ctx, latent); + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x)); + + auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); + // Earlier dump_vae.py used default causal=True for conv_in, but the real Python + // decoder.forward uses self.causal which is False — the dumper is now aligned. + x = conv_in->forward(ctx, x, DECODER_CAUSAL); + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x)); + + for (size_t i = 0; i < decoder_blocks.size(); ++i) { + const auto& b = decoder_blocks[decoder_blocks.size() - 1 - i]; + std::string key = "up_blocks." 
+ std::to_string(i); + if (b.kind == DecoderBlockKind::RES_X) { + auto blk = std::dynamic_pointer_cast(blocks[key]); + x = blk->forward(ctx, x, timestep_conditioning ? timestep : nullptr, trace_outputs, DECODER_CAUSAL); + } else { + auto blk = std::dynamic_pointer_cast(blocks[key]); + x = blk->forward(ctx, x, DECODER_CAUSAL); + } + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x)); + } + + // Final norm + AdaLN + SiLU + conv_out. + x = pixel_norm(ctx->ggml_ctx, x, eps); + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x)); + + if (timestep_conditioning) { + GGML_ASSERT(timestep != nullptr); + auto te = std::dynamic_pointer_cast(blocks["last_time_embedder"]); + // Python multiplies the timestep by timestep_scale_multiplier BEFORE the embed. + auto tsm = params["timestep_scale_multiplier"]; // scalar [1] + auto t_scaled = ggml_mul(ctx->ggml_ctx, timestep, tsm); + auto time_embed = te->forward(ctx, t_scaled); // ne=[2*feature_channels, 1] + + auto sst = params["last_scale_shift_table"]; // ne=[feature_channels, 2] + auto te2 = ggml_reshape_2d(ctx->ggml_ctx, time_embed, feature_channels, 2); + auto ada = ggml_add(ctx->ggml_ctx, te2, sst); // ne=[feature_channels, 2] + + auto shift = ggml_ext_slice(ctx->ggml_ctx, ada, 1, 0, 1); + auto scale = ggml_ext_slice(ctx->ggml_ctx, ada, 1, 1, 2); + shift = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, shift), 1, 1, 1, feature_channels); + scale = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, scale), 1, 1, 1, feature_channels); + + auto x_scaled = ggml_mul(ctx->ggml_ctx, x, scale); + x = ggml_add(ctx->ggml_ctx, x, x_scaled); + x = ggml_add(ctx->ggml_ctx, x, shift); + } + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x)); + + x = ggml_silu(ctx->ggml_ctx, x); + + auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); + x = conv_out->forward(ctx, x, DECODER_CAUSAL); + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, 
x)); + + x = unpatchify(ctx->ggml_ctx, x, 1, patch_size, patch_size); + if (trace_outputs) trace_outputs->push_back(ggml_cont(ctx->ggml_ctx, x)); + return x; + } + }; + + // ---------- GGMLRunner wrappers ---------- + + struct VAEEncoderRunner : public GGMLRunner { + VideoEncoder encoder; + int in_channels; + int latent_channels; + int patch_size; + + VAEEncoderRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + int in_ch, + int latent_ch, + int patch, + const std::vector& specs) + : GGMLRunner(backend, offload_params_to_cpu), + encoder(in_ch, latent_ch, patch, specs), + in_channels(in_ch), latent_channels(latent_ch), patch_size(patch) { + encoder.init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { return "ltx2_vae_encoder"; } + + void get_param_tensors(std::map& tensors, const std::string& prefix) { + encoder.get_param_tensors(tensors, prefix); + } + + // stage_index==-1 returns the final latent; >=0 returns the matching trace. + // Full forward is always built so buffer allocation covers every declared input. 
+ sd::Tensor compute(int n_threads, const sd::Tensor& video_tensor, + int stage_index = -1) { + auto get_g = [&]() -> ggml_cgraph* { + ggml_cgraph* gf = ggml_new_graph(compute_ctx); + ggml_tensor* x = make_input(video_tensor); + auto runner_ctx = get_context(); + std::vector traces; + ggml_tensor* final_out = encoder.forward(&runner_ctx, x, &traces); + ggml_build_forward_expand(gf, final_out); + if (stage_index >= 0 && stage_index < (int)traces.size()) { + ggml_build_forward_expand(gf, traces[stage_index]); + } + return gf; + }; + return take_or_empty(GGMLRunner::compute(get_g, n_threads, true)); + } + }; + + struct VAEDecoderRunner : public GGMLRunner { + VideoDecoder decoder; + + VAEDecoderRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + int latent_ch, + int out_ch, + int patch, + int base_ch, + bool timestep_cond, + const std::vector& specs) + : GGMLRunner(backend, offload_params_to_cpu), + decoder(latent_ch, out_ch, patch, base_ch, timestep_cond, specs) { + decoder.init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { return "ltx2_vae_decoder"; } + + void get_param_tensors(std::map& tensors, const std::string& prefix) { + decoder.get_param_tensors(tensors, prefix); + } + + // stage_index==-1 returns the final video output; >=0 returns the matching trace. + // We always build the FULL forward graph so every declared input has a backend + // buffer; when stage_index is set we just re-expand that trace last so it becomes + // the final-result tensor that GGMLRunner::compute extracts. + sd::Tensor compute(int n_threads, + const sd::Tensor& latent_tensor, + const sd::Tensor& timestep_tensor, + int stage_index = -1) { + auto get_g = [&]() -> ggml_cgraph* { + ggml_cgraph* gf = ggml_new_graph(compute_ctx); + ggml_tensor* z = make_input(latent_tensor); + ggml_tensor* t = timestep_tensor.empty() ? 
nullptr : make_input(timestep_tensor); + auto runner_ctx = get_context(); + std::vector traces; + ggml_tensor* final_out = decoder.forward(&runner_ctx, z, t, &traces); + ggml_build_forward_expand(gf, final_out); + if (stage_index >= 0 && stage_index < (int)traces.size()) { + ggml_build_forward_expand(gf, traces[stage_index]); + } + return gf; + }; + return take_or_empty(GGMLRunner::compute(get_g, n_threads, true)); + } + }; + + // ---------- Combined VAE runner ---------- + // + // Plumbs both VideoEncoder and VideoDecoder into the shared VAE interface so + // create_vae() in stable-diffusion.cpp can treat LTX-2 like any other VAE. + // + // Prefix convention matches the real LTX-2 checkpoint: `vae.encoder.*`, + // `vae.decoder.*`, `vae.per_channel_statistics.*`. Since our VideoEncoder and + // VideoDecoder each register a PerChannelStatisticsBlock under their own + // sub-prefix, we need the state dict to have nested PCS copies (which our + // parity dumper provides). Real LTX-2 checkpoints only ship the top-level + // `vae.per_channel_statistics.*` — see FUTURE note below. + + struct LTX2VAERunner : public VAE { + VideoEncoder encoder; + VideoDecoder decoder; + float decode_timestep = 0.05f; // Python default. + bool uses_timestep_conditioning = true; + + LTX2VAERunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + SDVersion version_, + int in_ch = 3, + int latent_ch = 128, + int patch = 4, + int decoder_base_ch = 128, + bool timestep_cond = true, + std::vector enc_specs = {}, + std::vector dec_specs = {}) + : VAE(version_, backend, offload_params_to_cpu), + encoder(in_ch, latent_ch, patch, enc_specs.empty() ? default_enc_specs() : enc_specs), + decoder(latent_ch, in_ch, patch, decoder_base_ch, timestep_cond, + dec_specs.empty() ? default_dec_specs() : dec_specs), + uses_timestep_conditioning(timestep_cond) { + // LTX-2 callers pass already-[-1,1] RGB to encode (e.g. 
preprocessing.hpp + // hands us raw pixel values mapped to [-1, 1]), so the encoder must NOT + // re-scale [0,1]→[-1,1]. The decoder, however, still produces [-1,1] + // outputs that downstream tensor_to_sd_image expects mapped to [0, 1] — + // so we leave scale_output at its default `true` (added separately to + // the base VAE for asymmetric paths like this one). + scale_input = false; + encoder.init(params_ctx, tensor_storage_map, prefix + ".encoder"); + decoder.init(params_ctx, tensor_storage_map, prefix + ".decoder"); + } + + // Production default: 1× compress_space_res, 1× compress_time_res, 2× compress_all_res, + // per the LTXV paper "Standard LTX Video configuration" docstring. + static std::vector default_enc_specs() { + return { + {EncoderBlockKind::COMPRESS_SPACE_RES, 1, 2}, + {EncoderBlockKind::COMPRESS_TIME_RES, 1, 2}, + {EncoderBlockKind::COMPRESS_ALL_RES, 1, 2}, + {EncoderBlockKind::COMPRESS_ALL_RES, 1, 2}, + }; + } + static std::vector default_dec_specs() { + // Stored in encoder-side order; VideoDecoder reverses. + return { + {DecoderBlockKind::COMPRESS_SPACE, 1, 1}, + {DecoderBlockKind::COMPRESS_TIME, 1, 1}, + {DecoderBlockKind::COMPRESS_ALL, 1, 1}, + {DecoderBlockKind::COMPRESS_ALL, 1, 1}, + }; + } + + // Real 22B LTX-2 video VAE spec, reverse-engineered from the checkpoint's + // weight shapes (encoder ch progression: 128 → 256 → 512 → 1024 → 1024): + // idx kind cur_c after + // 0 RES_X(4 layers) 128 + // 1 COMPRESS_SPACE_RES(m=2) 128 → 256 + // 2 RES_X(6 layers) 256 + // 3 COMPRESS_TIME_RES(m=2) 256 → 512 + // 4 RES_X(4 layers) 512 + // 5 COMPRESS_ALL_RES(m=2) 512 → 1024 + // 6 RES_X(2 layers) 1024 + // 7 COMPRESS_ALL_RES(m=1) 1024 → 1024 (spatial/temporal compress only) + // 8 RES_X(2 layers) 1024 + // Final conv_out: 1024 → 129 (128 latent + 1 logvar). + // Decoder mirrors in encoder-side order; VideoDecoder reverses at construct. 
+ static std::vector ltx2_22b_enc_specs() { + return { + {EncoderBlockKind::RES_X, 4, 1}, + {EncoderBlockKind::COMPRESS_SPACE_RES, 1, 2}, + {EncoderBlockKind::RES_X, 6, 1}, + {EncoderBlockKind::COMPRESS_TIME_RES, 1, 2}, + {EncoderBlockKind::RES_X, 4, 1}, + {EncoderBlockKind::COMPRESS_ALL_RES, 1, 2}, + {EncoderBlockKind::RES_X, 2, 1}, + {EncoderBlockKind::COMPRESS_ALL_RES, 1, 1}, + {EncoderBlockKind::RES_X, 2, 1}, + }; + } + static std::vector ltx2_22b_dec_specs() { + // Encoder-side order; VideoDecoder iterates in reverse at construct. + // Reverse iteration maps decoder_blocks[i] → up_blocks.[N-1-i], so the + // last entry here becomes up_blocks.0 (innermost, 1024 channels). + // + // Decoder channel progression (verified against real weight shapes): + // up_blocks.0 RES_X(2) @ 1024 + // up_blocks.1 COMPRESS_ALL(m=2) 1024 → 512 (conv:4096, d2s/8) + // up_blocks.2 RES_X(2) @ 512 + // up_blocks.3 COMPRESS_ALL(m=1) 512 → 512 (conv:4096, d2s/8) + // up_blocks.4 RES_X(4) @ 512 + // up_blocks.5 COMPRESS_TIME(m=2) 512 → 256 (conv:512, d2s/2) + // up_blocks.6 RES_X(6) @ 256 + // up_blocks.7 COMPRESS_SPACE(m=2) 256 → 128 (conv:512, d2s/4) + // up_blocks.8 RES_X(4) @ 128 + // Decoder's compress multipliers are NOT a mirror of the encoder's + // — the model is architecturally asymmetric (different res counts, different + // compress kinds at each level). Enc vs dec must each be traced separately. + // Per-block timestep_cond per the actual 22B checkpoint contents: + // NO up_block has per-res_block scale_shift_table + time_embedder + // weights — the LTX-2 22B VAE only conditions at the LAST step + // (last_scale_shift_table + last_time_embedder, gated by the + // top-level timestep_conditioning flag). 
+ return { + {DecoderBlockKind::RES_X, 4, 1, /*timestep_cond=*/false}, + {DecoderBlockKind::COMPRESS_SPACE, 1, 2}, + {DecoderBlockKind::RES_X, 6, 1, /*timestep_cond=*/false}, + {DecoderBlockKind::COMPRESS_TIME, 1, 2}, + {DecoderBlockKind::RES_X, 4, 1, /*timestep_cond=*/false}, + {DecoderBlockKind::COMPRESS_ALL, 1, 1}, + {DecoderBlockKind::RES_X, 2, 1, /*timestep_cond=*/false}, + {DecoderBlockKind::COMPRESS_ALL, 1, 2}, + {DecoderBlockKind::RES_X, 2, 1, /*timestep_cond=*/false}, + }; + } + + std::string get_desc() override { return "ltx2_vae"; } + + void get_param_tensors(std::map& tensors, const std::string prefix) override { + encoder.get_param_tensors(tensors, prefix + ".encoder"); + decoder.get_param_tensors(tensors, prefix + ".decoder"); + } + + int get_encoder_output_channels(int /*input_channels*/) override { + return 128; // latent_channels + } + + sd::Tensor vae_output_to_latents(const sd::Tensor& vae_output, + std::shared_ptr /*rng*/) override { + return vae_output; + } + sd::Tensor diffusion_to_vae_latents(const sd::Tensor& latents) override { + return latents; + } + sd::Tensor vae_to_diffusion_latents(const sd::Tensor& latents) override { + return latents; + } + + ggml_cgraph* build_graph(const sd::Tensor& z_tensor, bool decode_graph) { + // 10240 fit the 4-block parity test. The 22B VAE has 9 encoder + 9 + // decoder blocks with up to 6 res_blocks each, plus per-channel stats + // and conv_in/out. Bumped for safety. + ggml_cgraph* gf = new_graph_custom(65536); + ggml_tensor* z = make_input(z_tensor); + auto runner_ctx = get_context(); + ggml_tensor* out; + if (decode_graph) { + ggml_tensor* t = nullptr; + if (uses_timestep_conditioning) { + // Build a scalar timestep tensor inline (no external input needed). 
+ t = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, 1); + ggml_set_name(t, "ltx2_vae_decode_timestep"); + decode_timestep_backing.resize(1); + decode_timestep_backing[0] = decode_timestep; + set_backend_tensor_data(t, decode_timestep_backing.data()); + } + out = decoder.forward(&runner_ctx, z, t); + } else { + out = encoder.forward(&runner_ctx, z); + } + ggml_build_forward_expand(gf, out); + return gf; + } + + sd::Tensor _compute(const int n_threads, + const sd::Tensor& z, + bool decode_graph) override { + auto get_g = [&]() -> ggml_cgraph* { return build_graph(z, decode_graph); }; + auto out = take_or_empty(GGMLRunner::compute(get_g, n_threads, true)); + // Decoder output is [W, H, T, C]; decode_video_outputs + tensor_to_sd_image + // expect 5D [W, H, T, C, B] to pick the video-shaped index path. Add the + // trailing batch axis so the conversion uses the (iw, ih, frame, ic, 0) + // accessor (the 4D path assumes [W, H, C, F] which is the wrong layout). + if (decode_graph && !out.empty() && out.shape().size() == 4) { + auto s = out.shape(); + out.reshape_({s[0], s[1], s[2], s[3], 1}); + } + return out; + } + + private: + std::vector decode_timestep_backing; + }; + +} // namespace LTXVAE + +#endif // __LTXVAE_HPP__ diff --git a/src/ltxvae_primitives.hpp b/src/ltxvae_primitives.hpp new file mode 100644 index 000000000..05c301ecf --- /dev/null +++ b/src/ltxvae_primitives.hpp @@ -0,0 +1,212 @@ +#ifndef __LTXVAE_PRIMITIVES_HPP__ +#define __LTXVAE_PRIMITIVES_HPP__ + +#include "ggml.h" + +// Space-to-depth / depth-to-space helpers for the LTX-2 VAE. +// +// The VAE's `SpaceToDepthDownsample` and `DepthToSpaceUpsample` blocks compress +// or expand one or more of the (T, H, W) axes into/out of the channel axis. In +// einops notation (with B=1 elided): +// +// rearrange(x, "c (t p1) (h p2) (w p3) -> (c p1 p2 p3) t h w", ...) # space-to-depth +// rearrange(x, "(c p1 p2 p3) t h w -> c (t p1) (h p2) (w p3)", ...) 
# depth-to-space +// +// The einops grouping "(c p1 p2 p3)" puts p3 innermost (fastest-varying) within +// the merged channel axis, so c_new = c*p1*p2*p3 + i1*p2*p3 + i2*p3 + i3. +// +// GGML caps tensors at 4-D, which prevents a single reshape from representing the +// natural 5-D/6-D intermediate. We achieve the same result by folding the three +// strided axes ONE AT A TIME, composing three 4-D rearranges. The fold order +// matters: because each single-axis fold puts the just-folded factor innermost +// within the merged channel axis, folding in the order T→H→W produces p3 as the +// innermost factor in the final output — matching einops "(c p1 p2 p3)". +// +// Convention: all tensors use GGML ne=[W, H, T, C] (B=1 collapsed). A "factor" +// of 1 is a no-op; single-axis folds require the target axis to be divisible +// by factor. +// +// The primitives are verified byte-exact against PyTorch einops in the +// standalone test sd-s2d-primitives-test. + +namespace LTXVAE { + +// ---------- SpaceToDepth ---------- + +inline ggml_tensor* space_to_depth_axisW(ggml_context* ctx, ggml_tensor* x, int factor) { + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + GGML_ASSERT(W % factor == 0); + int64_t W_out = W / factor; + // Split innermost axis into [factor (innermost), W_out]. Merge H,T to stay 4D. + auto y = ggml_reshape_4d(ctx, x, factor, W_out, H * T, C); + // Move "factor" from axis 0 to axis 2 (adjacent to C). + // ggml_permute(a, p0, p1, p2, p3) says "old axis i goes to new position p_i". + // Here old→new: 0→2, 1→0, 2→1, 3→3. + y = ggml_cont(ctx, ggml_permute(ctx, y, 2, 0, 1, 3)); // ne=[W_out, H*T, factor, C] + // Merge (factor, C) with factor innermost of the new channel axis, matching einops (c p3). 
+ y = ggml_reshape_4d(ctx, y, W_out, H, T, C * factor); + return y; +} + +inline ggml_tensor* space_to_depth_axisH(ggml_context* ctx, ggml_tensor* x, int factor) { + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + GGML_ASSERT(H % factor == 0); + int64_t H_out = H / factor; + auto y = ggml_reshape_4d(ctx, x, W, factor, H_out * T, C); + y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3)); // ne=[W, H*T, factor, C] + y = ggml_reshape_4d(ctx, y, W, H_out, T, C * factor); + return y; +} + +inline ggml_tensor* space_to_depth_axisT(ggml_context* ctx, ggml_tensor* x, int factor) { + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + GGML_ASSERT(T % factor == 0); + int64_t T_out = T / factor; + auto y = ggml_reshape_4d(ctx, x, W * H, factor, T_out, C); + y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3)); // ne=[W*H, T, factor, C] + y = ggml_reshape_4d(ctx, y, W, H, T_out, C * factor); + return y; +} + +// Compose: fold T first (so p1 ends up outer), then H (p2), then W (p3 innermost) +// — matching einops "(c p1 p2 p3)" channel ordering. +inline ggml_tensor* space_to_depth(ggml_context* ctx, ggml_tensor* x, + int p1, int p2, int p3) { + if (p1 > 1) x = space_to_depth_axisT(ctx, x, p1); + if (p2 > 1) x = space_to_depth_axisH(ctx, x, p2); + if (p3 > 1) x = space_to_depth_axisW(ctx, x, p3); + return x; +} + +// ---------- DepthToSpace (inverse) ---------- +// +// Each single-axis depth-to-space splits the last axis (C_in = C_out * factor) +// with factor innermost, moves factor to the strided spatial axis, then merges. +// To invert space_to_depth's T→H→W fold order, we unfold in reverse: W→H→T. + +inline ggml_tensor* depth_to_space_axisW(ggml_context* ctx, ggml_tensor* x, int factor) { + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + GGML_ASSERT(C % factor == 0); + int64_t C_out = C / factor; + // Split last axis into [factor (innermost), C_out]. 
+ auto y = ggml_reshape_4d(ctx, x, W, H * T, factor, C_out); + // Inverse of the S2D-axisW permute (2,0,1,3). Inverse of that map is (1,2,0,3): + // old 0→new 1, old 1→new 2, old 2→new 0, old 3→new 3. + y = ggml_cont(ctx, ggml_permute(ctx, y, 1, 2, 0, 3)); // ne=[factor, W, H*T, C_out] + y = ggml_reshape_4d(ctx, y, W * factor, H, T, C_out); + return y; +} + +inline ggml_tensor* depth_to_space_axisH(ggml_context* ctx, ggml_tensor* x, int factor) { + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + GGML_ASSERT(C % factor == 0); + int64_t C_out = C / factor; + auto y = ggml_reshape_4d(ctx, x, W, H * T, factor, C_out); + // Inverse of S2D-axisH's (0,2,1,3) is itself (0,2,1,3). + y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3)); // ne=[W, factor, H*T, C_out] + y = ggml_reshape_4d(ctx, y, W, H * factor, T, C_out); + return y; +} + +inline ggml_tensor* depth_to_space_axisT(ggml_context* ctx, ggml_tensor* x, int factor) { + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + GGML_ASSERT(C % factor == 0); + int64_t C_out = C / factor; + auto y = ggml_reshape_4d(ctx, x, W * H, T, factor, C_out); + y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3)); // ne=[W*H, factor, T, C_out] + y = ggml_reshape_4d(ctx, y, W, H, T * factor, C_out); + return y; +} + +// Inverse of space_to_depth: unfold in reverse order (W first, then H, then T) +// because S2D folded T first, H, W. +inline ggml_tensor* depth_to_space(ggml_context* ctx, ggml_tensor* x, + int p1, int p2, int p3) { + if (p3 > 1) x = depth_to_space_axisW(ctx, x, p3); + if (p2 > 1) x = depth_to_space_axisH(ctx, x, p2); + if (p1 > 1) x = depth_to_space_axisT(ctx, x, p1); + return x; +} + +// ---------- patchify / unpatchify ---------- +// +// The VAE's patchify op uses a DIFFERENT channel ordering from the Downsample/Upsample +// blocks: einops `"b c (f p) (h q) (w r) -> b (c p r q) f h w"` — innermost within the +// merged channel axis is q (H-patch), NOT p3/W as elsewhere. 
To match, we fold in the +// order T (p), W (r), H (q) — last fold ends up innermost. + +inline ggml_tensor* patchify(ggml_context* ctx, ggml_tensor* x, int pt, int ph, int pw) { + if (pt > 1) x = space_to_depth_axisT(ctx, x, pt); + if (pw > 1) x = space_to_depth_axisW(ctx, x, pw); + if (ph > 1) x = space_to_depth_axisH(ctx, x, ph); + return x; +} + +inline ggml_tensor* unpatchify(ggml_context* ctx, ggml_tensor* x, int pt, int ph, int pw) { + if (ph > 1) x = depth_to_space_axisH(ctx, x, ph); + if (pw > 1) x = depth_to_space_axisW(ctx, x, pw); + if (pt > 1) x = depth_to_space_axisT(ctx, x, pt); + return x; +} + +// ---------- PixelNorm ---------- +// +// Python (ltx_core.model.common.normalization.PixelNorm, dim=1): +// y = x / sqrt(mean(x^2, dim=1, keepdim=True) + eps) +// PyTorch dim=1 is the channel axis. In our GGML layout ne=[W, H, T, C] that's +// ne[3] (outermost). ggml_rms_norm normalizes along ne[0] (innermost), so we +// permute C to innermost, rms-normalize, then permute back. +// +// This has NO learnable parameters — the Python PixelNorm is parameter-free. + +inline ggml_tensor* pixel_norm(ggml_context* ctx, ggml_tensor* x, float eps) { + int64_t W = x->ne[0], H = x->ne[1], T = x->ne[2], C = x->ne[3]; + // Move C to innermost. old→new: 0→1 (W to pos 1), 1→2 (H to 2), 2→3 (T to 3), 3→0 (C to 0). + auto y = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 3, 0)); // ne=[C, W, H, T] + y = ggml_rms_norm(ctx, y, eps); // normalize along ne[0]=C + // Permute back: C to outermost. old→new: 0→3, 1→0, 2→1, 3→2. + y = ggml_cont(ctx, ggml_permute(ctx, y, 3, 0, 1, 2)); // ne=[W, H, T, C] + (void)W; (void)H; (void)T; (void)C; + return y; +} + +// ---------- PerChannelStatistics ---------- +// +// Python: buffers `mean-of-means` [C] and `std-of-means` [C]. 
+// normalize(x) = (x - mean) / std +// un_normalize(x) = x * std + mean +// In GGML with ne=[W, H, T, C] and a 1D buffer of shape [C] (ne=[C, 1, 1, 1]), +// we broadcast over W*H*T by using the asymmetric-broadcast ggml_add/mul: +// ggml_mul(a, b) requires a->ne[i] % b->ne[i] == 0, so we pass x as `a` and the +// [C] buffer reshaped to ne=[1, 1, 1, C] as `b` — same outermost-axis shape. + +inline ggml_tensor* pcs_normalize(ggml_context* ctx, ggml_tensor* x, + ggml_tensor* mean_of_means, + ggml_tensor* std_of_means) { + int64_t C = x->ne[3]; + // Reshape both buffers to ne=[1, 1, 1, C] so they broadcast along W/H/T. + auto mu = ggml_reshape_4d(ctx, mean_of_means, 1, 1, 1, C); + auto sigma = ggml_reshape_4d(ctx, std_of_means, 1, 1, 1, C); + // (x - mu) / sigma = (x - mu) * (1/sigma). Compute the reciprocal by dividing. + // ggml doesn't have a direct div by tensor; emulate with ggml_div if available, + // else compute inv_sigma on the host. Since sigma is a loaded buffer (constant + // at inference), the cheapest is to do: x_shifted = x - mu; x_norm = x_shifted / sigma. 
+ auto x_shifted = ggml_sub(ctx, x, mu); + auto x_norm = ggml_div(ctx, x_shifted, sigma); + return x_norm; +} + +inline ggml_tensor* pcs_unnormalize(ggml_context* ctx, ggml_tensor* x, + ggml_tensor* mean_of_means, + ggml_tensor* std_of_means) { + int64_t C = x->ne[3]; + auto mu = ggml_reshape_4d(ctx, mean_of_means, 1, 1, 1, C); + auto sigma = ggml_reshape_4d(ctx, std_of_means, 1, 1, 1, C); + auto y = ggml_mul(ctx, x, sigma); + y = ggml_add(ctx, y, mu); + return y; +} + +} // namespace LTXVAE + +#endif diff --git a/src/model.cpp b/src/model.cpp index 3479a0bea..b169a6533 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -471,6 +471,9 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.layers.0.adaLN_sa_ln.weight") != std::string::npos) { return VERSION_ERNIE_IMAGE; } + if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.scale_shift_table") != std::string::npos) { + return VERSION_LTX2; + } if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) { is_wan = true; } @@ -1019,6 +1022,15 @@ bool ModelLoader::load_tensors(std::map& tensors, bool enable_mmap) { std::set tensor_names_in_file; std::mutex tensor_names_mutex; + // SD_QUIET_UNKNOWN_TENSORS=1 silences the per-tensor "unknown tensor" + // log line. The LTX-2 pipeline ships hundreds of audio / encoder + // tensors that the video-only path doesn't consume, swamping the log. + // Default is OFF (current behaviour, log every unknown). 
+ const char* env_quiet_unknown = std::getenv("SD_QUIET_UNKNOWN_TENSORS"); + const bool quiet_unknown = env_quiet_unknown != nullptr && env_quiet_unknown[0] != '\0' && + env_quiet_unknown[0] != '0'; + size_t unknown_count = 0; + std::mutex unknown_mutex; auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { const std::string& name = tensor_storage.name; // LOG_DEBUG("%s", tensor_storage.to_string().c_str()); @@ -1036,7 +1048,12 @@ bool ModelLoader::load_tensors(std::map& tensors, return true; } } - LOG_INFO("unknown tensor '%s' in model file", tensor_storage.to_string().c_str()); + if (quiet_unknown) { + std::lock_guard lock(unknown_mutex); + unknown_count++; + } else { + LOG_INFO("unknown tensor '%s' in model file", tensor_storage.to_string().c_str()); + } return true; } @@ -1085,6 +1102,9 @@ bool ModelLoader::load_tensors(std::map& tensors, if (some_tensor_not_init) { return false; } + if (quiet_unknown && unknown_count > 0) { + LOG_INFO("skipped %zu unknown tensors (SD_QUIET_UNKNOWN_TENSORS=1)", unknown_count); + } return true; } @@ -1134,6 +1154,9 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) } if (tensor_should_be_converted(tensor_storage, type)) { tensor_storage.type = type; + } else if (tensor_storage.expected_type != GGML_TYPE_COUNT && + tensor_storage.expected_type != tensor_storage.type) { + tensor_storage.type = tensor_storage.expected_type; } mem_size += tensor_storage.nbytes() + alignment; } diff --git a/src/model.h b/src/model.h index 65bc6c367..fd8ba6f21 100644 --- a/src/model.h +++ b/src/model.h @@ -45,6 +45,7 @@ enum SDVersion { VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, VERSION_ERNIE_IMAGE, + VERSION_LTX2, VERSION_COUNT, }; @@ -139,6 +140,13 @@ static inline bool sd_version_is_ernie_image(SDVersion version) { return false; } +static inline bool sd_version_is_ltx2(SDVersion version) { + if (version == VERSION_LTX2) { + return true; + } + return false; +} + static inline 
bool sd_version_uses_flux2_vae(SDVersion version) { if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version)) { return true; @@ -165,7 +173,8 @@ static inline bool sd_version_is_dit(SDVersion version) { sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_z_image(version) || - sd_version_is_ernie_image(version)) { + sd_version_is_ernie_image(version) || + sd_version_is_ltx2(version)) { return true; } return false; @@ -193,6 +202,8 @@ using TensorTypeRules = std::vector>; TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules); +bool is_unused_tensor(const std::string& name); + class ModelLoader { protected: SDVersion version_ = VERSION_COUNT; diff --git a/src/name_conversion.cpp b/src/name_conversion.cpp index 618c7f6e9..fb91ab346 100644 --- a/src/name_conversion.cpp +++ b/src/name_conversion.cpp @@ -653,6 +653,39 @@ std::string convert_diffusers_dit_to_original_lumina2(std::string name) { return name; } +std::string convert_diffusers_dit_to_original_ltx2(std::string name) { + // Maps diffusers' LTX Video Transformer naming → original LTX-2 naming. + // The GGML block tree mirrors the original Python class attribute names, so anything matching the + // original naming passes through. Only the few diffusers-specific renames are listed here. + static std::unordered_map ltx2_name_map; + + if (ltx2_name_map.empty()) { + // Input projection: diffusers names it x_embedder; original uses patchify_proj. + ltx2_name_map["x_embedder.weight"] = "patchify_proj.weight"; + ltx2_name_map["x_embedder.bias"] = "patchify_proj.bias"; + + // Timestep head: diffusers sometimes puts these under time_embed / time_text_embed while the + // original LTX-2 uses adaln_single.emb.timestep_embedder.linear_{1,2} and adaln_single.linear. 
+ ltx2_name_map["time_embed.timestep_embedder.linear_1.weight"] = "adaln_single.emb.timestep_embedder.linear_1.weight"; + ltx2_name_map["time_embed.timestep_embedder.linear_1.bias"] = "adaln_single.emb.timestep_embedder.linear_1.bias"; + ltx2_name_map["time_embed.timestep_embedder.linear_2.weight"] = "adaln_single.emb.timestep_embedder.linear_2.weight"; + ltx2_name_map["time_embed.timestep_embedder.linear_2.bias"] = "adaln_single.emb.timestep_embedder.linear_2.bias"; + ltx2_name_map["time_embed.linear.weight"] = "adaln_single.linear.weight"; + ltx2_name_map["time_embed.linear.bias"] = "adaln_single.linear.bias"; + ltx2_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "adaln_single.emb.timestep_embedder.linear_1.weight"; + ltx2_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "adaln_single.emb.timestep_embedder.linear_1.bias"; + ltx2_name_map["time_text_embed.timestep_embedder.linear_2.weight"] = "adaln_single.emb.timestep_embedder.linear_2.weight"; + ltx2_name_map["time_text_embed.timestep_embedder.linear_2.bias"] = "adaln_single.emb.timestep_embedder.linear_2.bias"; + + // Transformer block names typically match (attn1/attn2/ff/scale_shift_table), so nothing to rewrite. + // Output projection & scale_shift_table pass through. 
+ } + + replace_with_prefix_map(name, ltx2_name_map); + + return name; +} + std::string convert_other_dit_to_original_anima(std::string name) { static const std::string anima_net_prefix = "net."; if (!starts_with(name, anima_net_prefix)) { @@ -672,6 +705,8 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S name = convert_diffusers_dit_to_original_flux(name); } else if (sd_version_is_z_image(version)) { name = convert_diffusers_dit_to_original_lumina2(name); + } else if (sd_version_is_ltx2(version)) { + name = convert_diffusers_dit_to_original_ltx2(name); } else if (sd_version_is_anima(version)) { name = convert_other_dit_to_original_anima(name); } diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index b9d3e9af1..3d831eb51 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -8,6 +8,7 @@ #include "util.h" #include "auto_encoder_kl.hpp" +#include "backend_fit.hpp" #include "conditioner.hpp" #include "control.hpp" #include "denoiser.hpp" @@ -16,6 +17,7 @@ #include "lora.hpp" #include "pmid.hpp" #include "sample-cache.h" +#include "ltxvae.hpp" #include "tae.hpp" #include "vae.hpp" @@ -54,6 +56,7 @@ const char* model_version_to_str[] = { "Z-Image", "Ovis Image", "Ernie Image", + "LTX-2", }; const char* sampling_methods_str[] = { @@ -111,6 +114,12 @@ class StableDiffusionGGML { ggml_backend_t clip_backend = nullptr; ggml_backend_t control_net_backend = nullptr; ggml_backend_t vae_backend = nullptr; + // Actual device id that `backend` points at. `SD_CUDA_DEVICE` can go stale + // when auto-fit re-initialises `backend` onto a different GPU. We track the + // live value so per-component resolution can decide "same device as main" + // correctly. -1 means the main backend is CPU. 
+ int backend_device_id = -1; + static constexpr int BACKEND_DEVICE_CPU = -1; SDVersion version; bool vae_decode_only = false; @@ -148,6 +157,93 @@ class StableDiffusionGGML { bool is_using_v_parameterization = false; bool is_using_edm_v_parameterization = false; + // Populated by auto-fit (when --auto-fit is passed). When enabled, this + // overrides env-var based per-component placement. device_id == -1 means + // "no override" (fall through to env vars / defaults). + struct FitOverride { + bool enabled = false; + int dit_device_id = -1; // -1 = keep main backend + int vae_device_id = -2; // -2 = no override (distinguishes from "force CPU") + int cond_device_id = -2; + bool dit_offload_params = false; // force offload_params_to_cpu for DiT only + bool cond_offload_params = false; // force offload_params_to_cpu for Conditioner only + bool vae_offload_params = false; + bool vae_on_cpu = false; + bool cond_on_cpu = false; + }; + FitOverride fit_override; + + // Auto-fit VAE tiling: when auto_fit is enabled, generate_video / + // generate_image consult this budget at gen time to decide whether the + // current resolution needs VAE tiling (and at what tile size). Captured + // from sd_ctx_params at init so it survives past sd_ctx_params_t scope. + bool auto_fit_enabled = false; + int64_t auto_fit_vae_compute_reserve_bytes = 0; + bool auto_fit_vae_on_cpu = false; + + // Pending tensor-split decisions from auto-fit (consumed by init_tensor_split). + // When true, the corresponding component will be configured for split mode + // even when SD_CUDA_TENSOR_SPLIT_DIT / _COND env vars are unset. + bool pending_split_dit = false; + bool pending_split_cond = false; + + // Heuristic: VAE peak compute scales with the largest activation tensor, + // which is roughly proportional to the number of latent voxels times a + // per-voxel byte cost dominated by the deepest decoder block (1024-channel + // up_block in LTX-2). 
1 MiB/voxel is an empirical fit that lets a typical + // 480x320x25 LTX-2 decode (15*10*4=600 voxels) decode in one pass under + // the default 1024 MiB reserve while triggering tiling at HD+ (e.g. + // 1280x720x49 → 40*22*7=6160 voxels, tile side ≈ 12). + static constexpr int64_t LATENT_VOXEL_PEAK_BYTES = 1 * backend_fit::MiB; + + // Compute auto-tiling parameters for the VAE based on the captured + // auto-fit budget and the latent grid (lw × lh × t_latent). Modifies + // `vae_tiling_params` in place ONLY when: + // - auto-fit is enabled + // - the user did NOT explicitly enable tiling (we don't override) + // - the predicted peak exceeds the reserved budget + // Logs the chosen tile size; no-op otherwise. + void maybe_auto_set_vae_tiling(int64_t lw, int64_t lh, int64_t t_latent) { + if (!auto_fit_enabled) return; + if (vae_tiling_params.enabled) return; // user override: don't touch + if (auto_fit_vae_on_cpu) return; // CPU VAE: no VRAM budget + if (auto_fit_vae_compute_reserve_bytes <= 0) return; + if (lw <= 0 || lh <= 0 || t_latent <= 0) return; + + const int64_t total_voxels = lw * lh * t_latent; + const int64_t budget_voxels = + auto_fit_vae_compute_reserve_bytes / LATENT_VOXEL_PEAK_BYTES; + if (total_voxels <= budget_voxels) { + // Fits in a single decode pass — leave tiling disabled. + return; + } + + // Time stays untiled (the LTX-2 VAE has temporal coupling); spread + // the budget across spatial tiles. Pick a square tile_w × tile_h. + const int64_t voxels_per_tile = std::max(budget_voxels, 1); + const int64_t tile_area_max = std::max(voxels_per_tile / t_latent, 16); + int tile_side = static_cast(std::round(std::sqrt(double(tile_area_max)))); + // Clamp to [8, max(lw, lh)] so the VAE tiler doesn't get a degenerate + // tile size, and never go above the actual latent dim. 
+ tile_side = std::max(8, tile_side); + int tile_x = std::min(tile_side, static_cast(lw)); + int tile_y = std::min(tile_side, static_cast(lh)); + + vae_tiling_params.enabled = true; + vae_tiling_params.tile_size_x = tile_x; + vae_tiling_params.tile_size_y = tile_y; + vae_tiling_params.target_overlap = 0.5f; + vae_tiling_params.rel_size_x = 0.f; + vae_tiling_params.rel_size_y = 0.f; + + LOG_INFO("auto-fit: VAE tiling enabled (latent %lldx%lldx%lld = %lld voxels > " + "budget %lld voxels @ %lld MiB reserve); tile %dx%d (latent), overlap %.2f", + (long long)lw, (long long)lh, (long long)t_latent, (long long)total_voxels, + (long long)budget_voxels, + (long long)(auto_fit_vae_compute_reserve_bytes / backend_fit::MiB), + tile_x, tile_y, vae_tiling_params.target_overlap); + } + std::map tensors; // lora_name => multiplier @@ -155,6 +251,31 @@ class StableDiffusionGGML { std::shared_ptr denoiser = std::make_shared(); + // ModelLoader kept alive for the lifetime of the SD context. Lazy-load + // callbacks (registered on DiT / LLM runners) call back into this loader + // on the first compute() of each component to read weights from disk + // sequentially — keeps peak VRAM at max-across-phases instead of + // sum-of-components. Toggled per-component via SD_LAZY_LOAD_DIT / + // SD_LAZY_LOAD_COND env vars. + std::unique_ptr model_loader_; + std::set load_ignore_tensors; + bool lazy_load_dit = false; + bool lazy_load_cond = false; + + // Multi-GPU tensor split state. Populated when SD_CUDA_TENSOR_SPLIT is set. + // The extra GPU backends and the CPU fallback are owned here so they live + // as long as any GGMLRunner that references them via MultiBackendSpec. 
+ struct TensorSplitState { + bool enabled = false; + std::vector ratios; // per-device row-split ratios + int main_device = 0; + std::vector extra_backends; // additional GPU backends (excluding main) + ggml_backend_t cpu_fallback = nullptr; + bool split_dit = false; // apply to LTX-2 DiT + bool split_cond = false; // apply to conditioner (LLM/Gemma) + }; + TensorSplitState tensor_split_state; + StableDiffusionGGML() = default; ~StableDiffusionGGML() { @@ -167,13 +288,110 @@ class StableDiffusionGGML { if (vae_backend != backend) { ggml_backend_free(vae_backend); } + // Tensor-split extra GPUs + CPU fallback: free after the runners (they + // refer to these via MultiBackendSpec). Order: extras first, then CPU + // fallback, then main `backend`. + for (auto* b : tensor_split_state.extra_backends) { + if (b != nullptr && b != backend) { + ggml_backend_free(b); + } + } + tensor_split_state.extra_backends.clear(); + if (tensor_split_state.cpu_fallback != nullptr) { + ggml_backend_free(tensor_split_state.cpu_fallback); + tensor_split_state.cpu_fallback = nullptr; + } ggml_backend_free(backend); } + // Read an integer environment variable, returning `def` if unset or malformed. + static int get_env_int(const char* name, int def) { + const char* v = getenv(name); + if (v == nullptr || *v == '\0') return def; + try { + return std::stoi(v); + } catch (...) { + LOG_WARN("env %s: '%s' is not a valid integer, using default %d", name, v, def); + return def; + } + } + + // Initialize a GPU backend for the given device id, or fall back to CPU. + // For CUDA, device_id < 0 means "CPU only"; otherwise clamp to available count. + // `component_name` is used for log messages (e.g. "DiT", "Gemma", "VAE"). 
+ static ggml_backend_t init_device_backend(int device_id, const char* component_name) { + if (device_id < 0) { + LOG_INFO("%s: using CPU backend (device=-1)", component_name); + return ggml_backend_cpu_init(); + } +#ifdef SD_USE_CUDA + int count = ggml_backend_cuda_get_device_count(); + if (count <= 0) { + LOG_WARN("%s: no CUDA devices available, falling back to CPU", component_name); + return ggml_backend_cpu_init(); + } + if (device_id >= count) { + LOG_WARN("%s: CUDA device %d requested but only %d available, falling back to device 0", + component_name, device_id, count); + device_id = 0; + } + auto b = ggml_backend_cuda_init(device_id); + if (b != nullptr) { + LOG_INFO("%s: using CUDA device %d", component_name, device_id); + return b; + } + LOG_WARN("%s: CUDA device %d init failed, falling back to CPU", component_name, device_id); + return ggml_backend_cpu_init(); +#elif defined(SD_USE_VULKAN) + int count = ggml_backend_vk_get_device_count(); + if (count <= 0) { + LOG_WARN("%s: no Vulkan devices available, falling back to CPU", component_name); + return ggml_backend_cpu_init(); + } + if (device_id >= count) { + LOG_WARN("%s: Vulkan device %d requested but only %d available, falling back to device 0", + component_name, device_id, count); + device_id = 0; + } + auto b = ggml_backend_vk_init((size_t)device_id); + if (b != nullptr) { + LOG_INFO("%s: using Vulkan device %d", component_name, device_id); + return b; + } + LOG_WARN("%s: Vulkan device %d init failed, falling back to CPU", component_name, device_id); + return ggml_backend_cpu_init(); +#elif defined(SD_USE_SYCL) + auto b = ggml_backend_sycl_init(device_id); + if (b != nullptr) { + LOG_INFO("%s: using SYCL device %d", component_name, device_id); + return b; + } + LOG_WARN("%s: SYCL init failed, falling back to CPU", component_name); + return ggml_backend_cpu_init(); +#else + (void)device_id; + LOG_INFO("%s: using CPU backend", component_name); + return ggml_backend_cpu_init(); +#endif + } + + // Main 
backend init. Honours these env vars for per-component device placement + // (used by the init path below): + // SD_CUDA_DEVICE default CUDA device id (default 0) — also used for DiT + // SD_CUDA_DEVICE_CLIP text encoder / conditioner (falls back to SD_CUDA_DEVICE) + // SD_CUDA_DEVICE_VAE VAE (falls back to SD_CUDA_DEVICE) + // SD_CUDA_DEVICE_CONTROL ControlNet (falls back to SD_CUDA_DEVICE) + // SD_VK_DEVICE same pattern for the Vulkan build + // Setting any of these to -1 forces CPU for that component. + // + // `keep_clip_on_cpu` / `keep_vae_on_cpu` still take precedence and force CPU. + // For weights too big even for a dedicated device, use offload_params_to_cpu + // (keeps weights on CPU and streams per-step to GPU). void init_backend() { #ifdef SD_USE_CUDA - LOG_DEBUG("Using CUDA backend"); - backend = ggml_backend_cuda_init(0); + int main_dev = get_env_int("SD_CUDA_DEVICE", 0); + backend = init_device_backend(main_dev, "main"); + backend_device_id = ggml_backend_is_cpu(backend) ? BACKEND_DEVICE_CPU : main_dev; #endif #ifdef SD_USE_METAL LOG_DEBUG("Using Metal backend"); @@ -227,6 +445,205 @@ class StableDiffusionGGML { } } + // Resolve the backend for a sub-component by reading its env override (if set), + // otherwise reusing the main backend. Returns the main `backend` unchanged if + // the override matches the main device; otherwise creates a new backend (which + // the caller is responsible for freeing via the existing `!= backend` dtor check). + // `force_cpu` short-circuits to CPU regardless of the env var. + // `fit_device_id` is the auto-fit override: -2 means "no override", -1 means + // "force CPU", >=0 names a specific GPU. 
+ ggml_backend_t resolve_component_backend(const char* env_name, + const char* component_name, + bool force_cpu, + int fit_device_id = -2) { + if (force_cpu) { + if (ggml_backend_is_cpu(backend)) { + return backend; + } + LOG_INFO("%s: forced CPU backend", component_name); + return ggml_backend_cpu_init(); + } +#if defined(SD_USE_CUDA) || defined(SD_USE_VULKAN) || defined(SD_USE_SYCL) + // Reuse the main backend iff this component resolves to the same + // physical device. After auto-fit re-initialises the main backend + // onto a different GPU, `SD_CUDA_DEVICE` no longer reflects reality, + // so we compare against `backend_device_id` instead. + int override_dev; + if (fit_override.enabled && fit_device_id != -2) { + override_dev = fit_device_id; + } else { + override_dev = get_env_int(env_name, get_env_int("SD_CUDA_DEVICE", 0)); + } + if (override_dev == backend_device_id && !ggml_backend_is_cpu(backend)) { + return backend; + } + return init_device_backend(override_dev, component_name); +#else + (void)env_name; + (void)component_name; + (void)fit_device_id; + return backend; +#endif + } + + // Configure tensor split for multi-GPU systems. + // + // Source of ratios (in priority order): + // 1. SD_CUDA_TENSOR_SPLIT="W0,W1,..." — explicit user override. + // 2. sd_ctx_params->auto_tensor_split (default true): when >1 CUDA + // device is detected and the env is unset, compute ratios from the + // free VRAM of each device and apply them to the DiT. + // + // SD_CUDA_TENSOR_SPLIT_DIT=1 enables split for the LTX-2 DiT. + // SD_CUDA_TENSOR_SPLIT_COND=1 enables split for the conditioner (Gemma). + // Both default off when only the ratios env is set, so the user can dial + // in just one component. 
+ void init_tensor_split(bool auto_tensor_split) { +#ifdef SD_USE_CUDA + if (tensor_split_state.enabled) return; + const int dev_count = ggml_backend_cuda_get_device_count(); + + std::vector ratios; + bool from_env = false; + if (const char* split_env = getenv("SD_CUDA_TENSOR_SPLIT"); + split_env != nullptr && *split_env != '\0') { + from_env = true; + std::string s(split_env); + size_t i = 0; + while (i < s.size()) { + size_t j = s.find(',', i); + std::string tok = s.substr(i, (j == std::string::npos) ? std::string::npos : j - i); + try { + float v = std::stof(tok); + ratios.push_back(v); + } catch (...) { + LOG_WARN("SD_CUDA_TENSOR_SPLIT: bad token '%s' — disabling tensor split", + tok.c_str()); + return; + } + if (j == std::string::npos) break; + i = j + 1; + } + } else if (auto_tensor_split && dev_count >= 2) { + // Auto-derive: ratios proportional to free VRAM, one per device. + ratios.reserve(dev_count); + for (int d = 0; d < dev_count; d++) { + size_t free_b = 0, total_b = 0; + ggml_backend_cuda_get_device_memory(d, &free_b, &total_b); + ratios.push_back(static_cast(free_b) / float(backend_fit::MiB)); + } + LOG_INFO("auto tensor split: deriving DiT ratios from free VRAM " + "across %d CUDA device(s)", + dev_count); + } else { + return; // single GPU, env unset, auto disabled, or no CUDA + } + + if (dev_count < 2) { + if (from_env) { + LOG_WARN("SD_CUDA_TENSOR_SPLIT set but only %d CUDA device(s) available; ignoring", + dev_count); + } + return; + } + if ((int)ratios.size() > dev_count) { + LOG_WARN("SD_CUDA_TENSOR_SPLIT has %zu ratios but only %d CUDA devices; truncating", + ratios.size(), dev_count); + ratios.resize(dev_count); + } + // Pad with zeros so downstream code can index by device id without bounds checks. 
+ while ((int)ratios.size() < dev_count) ratios.push_back(0.f); + + int main_dev = backend_device_id; + if (main_dev < 0) { + LOG_WARN("SD_CUDA_TENSOR_SPLIT: main backend is not CUDA; ignoring tensor split"); + return; + } + // Determine which non-main devices have non-zero ratio — those need + // their own ggml_backend_cuda instance for sched routing. + std::vector participating_devices; + for (int d = 0; d < dev_count; d++) { + if (ratios[d] > 0.f && d != main_dev) { + participating_devices.push_back(d); + } + } + if (participating_devices.empty()) { + LOG_WARN("SD_CUDA_TENSOR_SPLIT: no non-main device has nonzero ratio; ignoring"); + return; + } + + tensor_split_state.enabled = true; + tensor_split_state.ratios = std::move(ratios); + tensor_split_state.main_device = main_dev; + + for (int d : participating_devices) { + ggml_backend_t b = ggml_backend_cuda_init(d); + if (b == nullptr) { + LOG_WARN("SD_CUDA_TENSOR_SPLIT: failed to init CUDA device %d; skipping", d); + continue; + } + tensor_split_state.extra_backends.push_back(b); + } + if (tensor_split_state.extra_backends.empty()) { + LOG_WARN("SD_CUDA_TENSOR_SPLIT: no extra CUDA backends could be initialised; disabling"); + tensor_split_state.enabled = false; + return; + } + + tensor_split_state.cpu_fallback = ggml_backend_cpu_init(); + // Env vars take priority; otherwise honour what auto-fit decided + // (pending_split_dit / pending_split_cond). If neither: default to + // DiT-only split for the multi-GPU case. 
+ tensor_split_state.split_dit = get_env_int("SD_CUDA_TENSOR_SPLIT_DIT", 0) != 0 + || pending_split_dit; + tensor_split_state.split_cond = get_env_int("SD_CUDA_TENSOR_SPLIT_COND", 0) != 0 + || pending_split_cond; + if (!tensor_split_state.split_dit && !tensor_split_state.split_cond) { + tensor_split_state.split_dit = true; + LOG_INFO("SD_CUDA_TENSOR_SPLIT: neither SD_CUDA_TENSOR_SPLIT_DIT nor _COND set; " + "defaulting to DiT-only split"); + } + + std::string ratio_str; + for (size_t k = 0; k < tensor_split_state.ratios.size(); k++) { + if (k > 0) ratio_str += ","; + char buf[32]; + std::snprintf(buf, sizeof(buf), "%.2f", tensor_split_state.ratios[k]); + ratio_str += buf; + } + LOG_INFO("tensor split enabled: ratios=[%s] main_dev=%d extras=%zu DiT=%s Cond=%s", + ratio_str.c_str(), main_dev, tensor_split_state.extra_backends.size(), + tensor_split_state.split_dit ? "yes" : "no", + tensor_split_state.split_cond ? "yes" : "no"); +#endif + } + + // Populate g_pending_multi_backend_spec() with this runner's split config + // if `apply` is true. The pending spec is consumed (cleared) by the next + // GGMLRunner ctor. Returns a guard struct that clears the pending spec on + // destruction (in case construction throws and the runner ctor doesn't + // run). 
+ struct PendingSplitGuard { + MultiBackendSpec spec; + bool active = false; + ~PendingSplitGuard() { + if (active && g_pending_multi_backend_spec() == &spec) { + g_pending_multi_backend_spec() = nullptr; + } + } + }; + std::unique_ptr begin_pending_split(bool apply) { + if (!apply || !tensor_split_state.enabled) return nullptr; + auto g = std::make_unique(); + g->spec.extra_backends = tensor_split_state.extra_backends; + g->spec.tensor_split = tensor_split_state.ratios; + g->spec.main_device = tensor_split_state.main_device; + g->spec.cpu_fallback = tensor_split_state.cpu_fallback; + g->active = true; + g_pending_multi_backend_spec() = &g->spec; + return g; + } + std::shared_ptr get_rng(rng_type_t rng_type) { if (rng_type == STD_DEFAULT_RNG) { return std::make_shared(); @@ -255,8 +672,22 @@ class StableDiffusionGGML { ggml_log_set(ggml_log_callback_default, nullptr); init_backend(); - - ModelLoader model_loader; + // tensor split is initialised AFTER auto-fit, since auto-fit may decide + // to place a component in tensor-split mode. + + // Lazy load — let DiT and Conditioner-LLM lazy-load weights on first + // compute() instead of all-at-once at init. Required when sum-of- + // components exceeds combined VRAM (e.g. Q6_K LTX-2 + Q8_K_XL Gemma + // + connector + VAE on a 24 GB combined system). Defaults to ON via + // sd_ctx_params; env vars (SD_LAZY_LOAD_DIT / SD_LAZY_LOAD_COND) act + // as force-on overrides (they do not disable). 
+ lazy_load_dit = sd_ctx_params->lazy_load_dit || get_env_int("SD_LAZY_LOAD_DIT", 0) != 0; + lazy_load_cond = sd_ctx_params->lazy_load_cond || get_env_int("SD_LAZY_LOAD_COND", 0) != 0; + if (lazy_load_dit) LOG_INFO("lazy load: DiT (alloc + read on first compute, free after free_params_immediately)"); + if (lazy_load_cond) LOG_INFO("lazy load: LLM (Gemma allocs on first encode; connector stays eager)"); + + model_loader_ = std::make_unique(); + ModelLoader& model_loader = *model_loader_; if (strlen(SAFE_STR(sd_ctx_params->model_path)) > 0) { LOG_INFO("loading model from '%s'", sd_ctx_params->model_path); @@ -352,6 +783,115 @@ class StableDiffusionGGML { auto& tensor_storage_map = model_loader.get_tensor_storage_map(); + // LTX-2 prefix + Gemma sandwich-norm fixup: the conditioner expects Gemma at + // `text_encoder.model.*`, but `--llm-path` prepends `text_encoders.llm.*` + // (convert_tensors_name then maps gguf llama names to HF names, yielding + // `text_encoders.llm.model.*`). + // + // Additionally, Gemma 3 has 4 layernorms per block (sandwich norms) that the + // shared llm_name_map only partly translates. The raw GGUF names blk.N.{attn_norm, + // post_attention_norm, ffn_norm, post_ffw_norm} end up as HF-style + // input_layernorm + post_attention_norm + post_attention_layernorm + post_ffw_norm + // after the generic map (where ffn_norm→post_attention_layernorm is Qwen-correct + // but wrong for Gemma). We rename here once version is LTX-2: + // post_attention_layernorm → pre_feedforward_layernorm (was actually ffn_norm) + // post_attention_norm → post_attention_layernorm (append _layernorm) + // post_ffw_norm → post_feedforward_layernorm + // Order matters: do the first rename first so the second can safely write to + // the now-vacated post_attention_layernorm slot. + if (sd_version_is_ltx2(version)) { + // Step 1: prefix rewrite text_encoders.llm. → text_encoder. 
+ const std::string from = "text_encoders.llm."; + const std::string to = "text_encoder."; + { + String2TensorStorage renamed; + size_t renames = 0; + for (auto& kv : tensor_storage_map) { + const std::string& k = kv.first; + std::string new_k = k; + if (k.rfind(from, 0) == 0) { + new_k = to + k.substr(from.size()); + kv.second.name = new_k; + renames++; + } + renamed[new_k] = std::move(kv.second); + } + if (renames > 0) { + tensor_storage_map.swap(renamed); + LOG_INFO("LTX-2: renamed %zu '%s*' tensors → '%s*' (Gemma text encoder path)", + renames, from.c_str(), to.c_str()); + } + } + + // Step 2: Gemma 3 sandwich-norm renames, applied in the order documented + // above. Each pass rebuilds the storage map because std::map keys are const. + auto rename_suffix = [&](const std::string& old_suffix, const std::string& new_suffix) -> size_t { + String2TensorStorage renamed; + size_t renames = 0; + for (auto& kv : tensor_storage_map) { + const std::string& k = kv.first; + std::string new_k = k; + size_t p = k.rfind(old_suffix); + if (p != std::string::npos && p + old_suffix.size() == k.size()) { + // Only rename if prefix looks like a Gemma layer key. 
+ if (k.find("text_encoder.model.layers.") != std::string::npos) { + new_k = k.substr(0, p) + new_suffix; + kv.second.name = new_k; + renames++; + } + } + renamed[new_k] = std::move(kv.second); + } + tensor_storage_map.swap(renamed); + return renames; + }; + size_t r1 = rename_suffix(".post_attention_layernorm.weight", ".pre_feedforward_layernorm.weight"); + size_t r2 = rename_suffix(".post_attention_norm.weight", ".post_attention_layernorm.weight"); + size_t r3 = rename_suffix(".post_ffw_norm.weight", ".post_feedforward_layernorm.weight"); + if (r1 + r2 + r3 > 0) { + LOG_INFO("LTX-2: Gemma sandwich-norm rename: %zu pre_ff, %zu post_attn, %zu post_ff", + r1, r2, r3); + } + + // Step 3: Duplicate `first_stage_model.per_channel_statistics.*` into the + // `first_stage_model.encoder.per_channel_statistics.*` path expected by + // VideoEncoder's child block tree. VideoDecoder also expects these under + // its `decoder.per_channel_statistics` subprefix. Real LTX-2 checkpoints + // only ship the top-level buffer (mean-of-means, std-of-means). + { + const std::string top_pre = "first_stage_model.per_channel_statistics."; + size_t copied = 0; + // Snapshot keys with top_pre first (iteration + insertion is unsafe). + std::vector> to_copy; + for (auto& kv : tensor_storage_map) { + const std::string& k = kv.first; + if (k.rfind(top_pre, 0) == 0) { + std::string suffix = k.substr(top_pre.size()); + to_copy.push_back({k, suffix}); + } + } + for (auto& pair : to_copy) { + const std::string& src_key = pair.first; + const std::string& suffix = pair.second; + auto src_it = tensor_storage_map.find(src_key); + if (src_it == tensor_storage_map.end()) continue; + for (const char* sub : {"encoder", "decoder"}) { + std::string dst_key = "first_stage_model." + std::string(sub) + + ".per_channel_statistics." 
+ suffix; + if (tensor_storage_map.find(dst_key) != tensor_storage_map.end()) continue; + TensorStorage dup = src_it->second; + dup.name = dst_key; + tensor_storage_map[dst_key] = dup; + copied++; + } + } + if (copied > 0) { + LOG_INFO("LTX-2: duplicated %zu PerChannelStatistics entries to encoder/decoder subprefixes", + copied); + } + } + } + LOG_INFO("Version: %s ", model_version_to_str[version]); ggml_type wtype = (int)sd_ctx_params->wtype < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT) ? (ggml_type)sd_ctx_params->wtype @@ -361,6 +901,7 @@ class StableDiffusionGGML { model_loader.set_wtype_override(wtype, tensor_type_rules); } + std::map wtype_stat = model_loader.get_wtype_stat(); std::map conditioner_wtype_stat = model_loader.get_conditioner_wtype_stat(); std::map diffusion_model_wtype_stat = model_loader.get_diffusion_model_wtype_stat(); @@ -387,6 +928,113 @@ class StableDiffusionGGML { LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor)); + // ------------------------------------------------------------------ + // Auto-fit: compute per-component GPU/CPU placement plan based on + // currently free device memory. Runs before backend resolution so we + // can redirect the DiT backend and set per-component placement flags. + // Only affects the run when sd_ctx_params->auto_fit is true. 
+ if (sd_ctx_params->auto_fit) { + backend_fit::ComputeReserves reserves; + if (sd_ctx_params->auto_fit_compute_reserve_dit_mb > 0) { + reserves.dit_bytes = + int64_t(sd_ctx_params->auto_fit_compute_reserve_dit_mb) * backend_fit::MiB; + } + if (sd_ctx_params->auto_fit_compute_reserve_vae_mb > 0) { + reserves.vae_bytes = + int64_t(sd_ctx_params->auto_fit_compute_reserve_vae_mb) * backend_fit::MiB; + } + if (sd_ctx_params->auto_fit_compute_reserve_cond_mb > 0) { + reserves.conditioner_bytes = + int64_t(sd_ctx_params->auto_fit_compute_reserve_cond_mb) * backend_fit::MiB; + } + + const int64_t alignment_guess = 256; + auto components = backend_fit::estimate_components( + model_loader, wtype, alignment_guess, reserves); + auto devices = backend_fit::enumerate_gpu_devices(); + int64_t margin_bytes = + int64_t(std::max(0, sd_ctx_params->auto_fit_target_mb)) * backend_fit::MiB; + const bool allow_split = sd_ctx_params->auto_tensor_split && + devices.size() >= 2 && + getenv("SD_CUDA_TENSOR_SPLIT") == nullptr; + auto plan = backend_fit::compute_plan(components, devices, margin_bytes, allow_split); + backend_fit::print_plan(plan, components, devices, margin_bytes); + + if (sd_ctx_params->auto_fit_dry_run) { + LOG_INFO("auto-fit: --fit-dry-run set, aborting init before loading models"); + return false; + } + + // Apply plan to fit_override. + fit_override.enabled = true; + auto dit_d = backend_fit::find_decision(plan, backend_fit::ComponentKind::DIT); + auto vae_d = backend_fit::find_decision(plan, backend_fit::ComponentKind::VAE); + auto cond_d = backend_fit::find_decision(plan, backend_fit::ComponentKind::CONDITIONER); + + if (dit_d) { + fit_override.dit_device_id = dit_d->device_id; + fit_override.dit_offload_params = + (dit_d->placement == backend_fit::Placement::GPU_OFFLOAD_PARAMS); + // Re-init the main backend if the chosen DiT device differs from + // whatever init_backend() picked. 
Keep `backend_device_id` in + // sync — it's what resolve_component_backend compares against. + const int current_dev = backend_device_id; + if (!ggml_backend_is_cpu(backend) && dit_d->placement == backend_fit::Placement::CPU) { + LOG_INFO("auto-fit: switching DiT backend from GPU %d to CPU", current_dev); + ggml_backend_free(backend); + backend = ggml_backend_cpu_init(); + backend_device_id = BACKEND_DEVICE_CPU; + } else if (dit_d->placement != backend_fit::Placement::CPU && + dit_d->device_id != current_dev) { + LOG_INFO("auto-fit: switching DiT backend from GPU %d to GPU %d", + current_dev, dit_d->device_id); + ggml_backend_free(backend); + backend = init_device_backend(dit_d->device_id, "DiT (auto-fit)"); + backend_device_id = dit_d->device_id; + } + } + if (vae_d) { + fit_override.vae_device_id = vae_d->device_id; + fit_override.vae_on_cpu = (vae_d->placement == backend_fit::Placement::CPU); + fit_override.vae_offload_params = + (vae_d->placement == backend_fit::Placement::GPU_OFFLOAD_PARAMS); + } + if (cond_d) { + fit_override.cond_device_id = cond_d->device_id; + fit_override.cond_on_cpu = (cond_d->placement == backend_fit::Placement::CPU); + fit_override.cond_offload_params = + (cond_d->placement == backend_fit::Placement::GPU_OFFLOAD_PARAMS); + } + + // Capture state for auto-VAE-tiling at gen time. We can't read + // sd_ctx_params after this scope, so store what we'll need. + auto_fit_enabled = true; + auto_fit_vae_compute_reserve_bytes = reserves.vae_bytes; + auto_fit_vae_on_cpu = vae_d && vae_d->placement == backend_fit::Placement::CPU; + + // If auto-fit placed any component in tensor-split mode, capture + // that here so init_tensor_split below configures the right flags. 
+ const bool dit_split = dit_d && dit_d->placement == backend_fit::Placement::GPU_TENSOR_SPLIT; + const bool cond_split = cond_d && cond_d->placement == backend_fit::Placement::GPU_TENSOR_SPLIT; + if (dit_split) pending_split_dit = true; + if (cond_split) pending_split_cond = true; + + // For tensor-split components, the device id semantically means + // "main GPU + extras"; pin that to whatever main backend ended up + // active so resolve_component_backend doesn't try to forward to + // a non-main GPU. + if (dit_split) { + fit_override.dit_device_id = backend_device_id; + fit_override.dit_offload_params = false; + } + if (cond_split) { + fit_override.cond_device_id = backend_device_id; + fit_override.cond_on_cpu = false; + fit_override.cond_offload_params = false; + } + } + init_tensor_split(sd_ctx_params->auto_tensor_split); + if (sd_ctx_params->lora_apply_mode == LORA_APPLY_AUTO) { bool have_quantized_weight = false; if (wtype != GGML_TYPE_COUNT && ggml_is_quantized(wtype)) { @@ -426,19 +1074,76 @@ class StableDiffusionGGML { } bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu; + if (fit_override.enabled && fit_override.cond_on_cpu) { + clip_on_cpu = true; + } + + // LTX-2 Gemma 3: default the text encoder to CPU. Per-layer F32 + // reduction-order disagreement between cuBLAS's tile reduction and + // CPU's SIMD horizontal reduction seeds ~6e-5 drift at the first + // RMSNorm and compounds across 48 transformer layers; the final + // hidden state ends up ~10× off in absolute terms, which is enough + // to alter prompt semantics in the LTX-2 conditioner (e.g. drop the + // subject — "cat on beach" → "person on beach"). Allow the user to + // override via SD_CUDA_DEVICE_CLIP=N or via auto-fit; warn loudly + // when they do, especially on quantized weights. 
+ if (sd_version_is_ltx2(version) && !clip_on_cpu) { + const char* explicit_clip = getenv("SD_CUDA_DEVICE_CLIP"); + const bool user_set_cpu = (explicit_clip != nullptr && std::atoi(explicit_clip) < 0); + const bool user_set_cuda = (explicit_clip != nullptr && std::atoi(explicit_clip) >= 0); + const bool autofit_picks = fit_override.enabled && fit_override.cond_device_id >= 0; + if (user_set_cpu) { + clip_on_cpu = true; // explicit -1 → CPU, no message needed + } else if (!user_set_cuda && !autofit_picks) { + clip_on_cpu = true; + LOG_INFO("LTX-2: defaulting Gemma 3 text encoder to CPU " + "(CUDA path has cumulative F32 drift that can alter prompt semantics). " + "Set SD_CUDA_DEVICE_CLIP=N to run on CUDA device N anyway."); + } else { + auto q_proj_it = tensor_storage_map.find("text_encoder.model.layers.0.self_attn.q_proj.weight"); + const bool gemma_quantized = (q_proj_it != tensor_storage_map.end() && + ggml_is_quantized(q_proj_it->second.type)); + if (gemma_quantized) { + LOG_WARN("LTX-2: running QUANTIZED Gemma 3 text encoder on CUDA. " + "Cumulative F32 reduction-order drift across 48 layers can shift " + "the prompt embedding enough to lose subject/style cues " + "(e.g. \"cat on a beach\" → \"person on a beach\"). " + "Unset SD_CUDA_DEVICE_CLIP (or set it to -1) to use the CPU " + "encoder for full prompt fidelity."); + } else { + LOG_WARN("LTX-2: running Gemma 3 text encoder on CUDA. " + "Cumulative F32 reduction-order drift across 48 layers may alter " + "prompt semantics. Unset SD_CUDA_DEVICE_CLIP to use CPU."); + } + } + } + + // Per-component offload flags. `offload_params_to_cpu` (the user's + // global --offload-to-cpu) applies to every component. Auto-fit may + // additionally force DiT-only offload when the DiT doesn't fit in + // VRAM; that MUST NOT be propagated to the Conditioner/VAE, otherwise + // their weights get pinned in RAM and the system can OOM (e.g. an + // LTX-2 run pinning Gemma 9.5 GB + DiT 13 GB + VAE 1.4 GB in 32 GB RAM). 
+ const bool dit_offload = offload_params_to_cpu || + (fit_override.enabled && fit_override.dit_offload_params); + const bool cond_offload = offload_params_to_cpu || + (fit_override.enabled && fit_override.cond_offload_params); + const bool vae_offload = offload_params_to_cpu || + (fit_override.enabled && fit_override.vae_offload_params); { - clip_backend = backend; - if (clip_on_cpu && !ggml_backend_is_cpu(backend)) { - LOG_INFO("CLIP: Using CPU backend"); - clip_backend = ggml_backend_cpu_init(); - } + // Pick a device for the text-encoder stack. SD_CUDA_DEVICE_CLIP overrides + // (set to -1 for CPU); `keep_clip_on_cpu` still forces CPU regardless. + // When auto-fit is active, fit_override.cond_device_id wins. + clip_backend = resolve_component_backend( + "SD_CUDA_DEVICE_CLIP", "CLIP/TextEncoder", clip_on_cpu, + fit_override.enabled ? fit_override.cond_device_id : -2); if (sd_version_is_sd3(version)) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map); } else if (sd_version_is_flux(version)) { bool is_chroma = false; @@ -459,53 +1164,53 @@ class StableDiffusionGGML { } cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, sd_ctx_params->chroma_use_t5_mask, sd_ctx_params->chroma_t5_mask_pad); } else if (version == VERSION_OVIS_IMAGE) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, version, "", false); } else { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map); } diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, version, sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_flux2(version)) { bool is_chroma = false; cond_stage_model = 
std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, version, sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_wan(version)) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, true, 0, true); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.diffusion_model", version); if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { high_noise_diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.high_noise_diffusion_model", version); @@ -514,7 +1219,7 @@ class StableDiffusionGGML { diffusion_model->get_desc() == "Wan2.1-FLF2V-14B" || diffusion_model->get_desc() == "Wan2.1-I2V-1.3B") { clip_vision = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map); clip_vision->alloc_params_buffer(); clip_vision->get_param_tensors(tensors); @@ -525,44 +1230,67 @@ class StableDiffusionGGML { enable_vision = true; } cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, version, "", enable_vision); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.diffusion_model", version, sd_ctx_params->qwen_image_zero_cond_t); } else if (sd_version_is_anima(version)) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.diffusion_model"); } else if (sd_version_is_z_image(version)) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, version); diffusion_model = 
std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.diffusion_model", version); } else if (sd_version_is_ernie_image(version)) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "model.diffusion_model"); + } else if (sd_version_is_ltx2(version)) { + // LTX-2: Gemma 3 text encoder (Phase 8), 1D embeddings connector + DiT + // caption_projection (Phase 9), and LTX-2 causal 3D VAE (Phase 11) are all + // landed. LTX2GemmaConditioner auto-detects connector presence from the + // tensor map; if absent it falls back to Gemma's last_hidden_state. + // The tokenizer.json path is required — prompts can't be encoded without + // it. Any HuggingFace-format `tokenizer.json` for Gemma 3 works. + { + auto split_guard = begin_pending_split(tensor_split_state.split_cond); + cond_stage_model = std::make_shared(clip_backend, + cond_offload, + tensor_storage_map, + "text_encoder", + SAFE_STR(sd_ctx_params->gemma_tokenizer_path)); + } + { + auto split_guard = begin_pending_split(tensor_split_state.split_dit); + diffusion_model = std::make_shared(backend, + dit_offload, + tensor_storage_map, + "model.diffusion_model", + version); + } } else { // SD1.x SD2.x SDXL std::map embbeding_map; for (uint32_t i = 0; i < sd_ctx_params->embedding_count; i++) { @@ -570,20 +1298,20 @@ class StableDiffusionGGML { } if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, embbeding_map, version, PM_VERSION_2); } else { cond_stage_model = std::make_shared(clip_backend, - offload_params_to_cpu, + cond_offload, tensor_storage_map, embbeding_map, version); } diffusion_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, version); if 
(sd_ctx_params->diffusion_conv_direct) { @@ -592,11 +1320,40 @@ class StableDiffusionGGML { } } + // ---- Conditioner: optionally lazy-load the LLM (e.g. Gemma) ---- + // The connector stays eager because it's tiny relative to the LLM. + // Lazy mode skips adding LLM tensors to the global eager-load map + // and registers a callback that reads them from disk on first encode. + if (lazy_load_cond) { + std::map llm_lazy_tensors; + cond_stage_model->get_llm_param_tensors(llm_lazy_tensors); + int n_threads_local = n_threads; + bool enable_mmap = sd_ctx_params->enable_mmap; + cond_stage_model->set_llm_lazy_load([this, llm_lazy_tensors, n_threads_local, enable_mmap]() mutable { + return model_loader_->load_tensors(llm_lazy_tensors, load_ignore_tensors, n_threads_local, enable_mmap); + }); + size_t before = tensors.size(); + cond_stage_model->get_non_llm_param_tensors(tensors); + LOG_INFO("lazy_cond: %zu LLM tensors (lazy), +%zu non-LLM tensors (eager)", + llm_lazy_tensors.size(), tensors.size() - before); + } else { + cond_stage_model->get_param_tensors(tensors); + } cond_stage_model->alloc_params_buffer(); - cond_stage_model->get_param_tensors(tensors); + // ---- DiT: optionally lazy-load entirely (single inner runner) ---- + if (lazy_load_dit) { + std::map dit_lazy_tensors; + diffusion_model->get_param_tensors(dit_lazy_tensors); + int n_threads_local = n_threads; + bool enable_mmap = sd_ctx_params->enable_mmap; + diffusion_model->set_lazy_load([this, dit_lazy_tensors, n_threads_local, enable_mmap]() mutable { + return model_loader_->load_tensors(dit_lazy_tensors, load_ignore_tensors, n_threads_local, enable_mmap); + }); + } else { + diffusion_model->get_param_tensors(tensors); + } diffusion_model->alloc_params_buffer(); - diffusion_model->get_param_tensors(tensors); if (sd_version_is_unet_edit(version)) { vae_decode_only = false; @@ -607,19 +1364,23 @@ class StableDiffusionGGML { high_noise_diffusion_model->get_param_tensors(tensors); } - if 
(sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) { - LOG_INFO("VAE Autoencoder: Using CPU backend"); - vae_backend = ggml_backend_cpu_init(); - } else { - vae_backend = backend; + // Pick a device for the VAE. SD_CUDA_DEVICE_VAE overrides (set to -1 for CPU); + // `keep_vae_on_cpu` still forces CPU regardless. Auto-fit, when active, + // supplies fit_override.vae_device_id which takes precedence over env. + bool vae_on_cpu = sd_ctx_params->keep_vae_on_cpu; + if (fit_override.enabled && fit_override.vae_on_cpu) { + vae_on_cpu = true; } + vae_backend = resolve_component_backend( + "SD_CUDA_DEVICE_VAE", "VAE", vae_on_cpu, + fit_override.enabled ? fit_override.vae_device_id : -2); auto create_tae = [&]() -> std::shared_ptr { if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { return std::make_shared(vae_backend, - offload_params_to_cpu, + vae_offload, tensor_storage_map, "decoder", vae_decode_only, @@ -627,7 +1388,7 @@ class StableDiffusionGGML { } else { auto model = std::make_shared(vae_backend, - offload_params_to_cpu, + vae_offload, tensor_storage_map, "decoder.layers", vae_decode_only, @@ -641,14 +1402,38 @@ class StableDiffusionGGML { sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { return std::make_shared(vae_backend, - offload_params_to_cpu, + vae_offload, tensor_storage_map, "first_stage_model", vae_decode_only, version); + } else if (sd_version_is_ltx2(version)) { + // LTX-2 VAE: in the real checkpoint after convert_tensors_name, + // the `vae.` → `first_stage_model.` rename from name_conversion.cpp + // puts weights under the standard `first_stage_model.` prefix. The + // sd-vae-parity test uses a pre-named `vae.` state dict directly so + // it can run on the parity dumper's output without going through the + // conversion pass. 
+ // + // The 22B checkpoint (see `ltx2_22b_{enc,dec}_specs`) has a 9-block + // encoder/decoder with mixed RES_X and COMPRESS_* blocks — much deeper + // than the 4-block tiny-test default. We hardcode the 22B spec here for + // the smoke test; a proper auto-detect from tensor shapes is a follow-up. + return std::make_shared(vae_backend, + vae_offload, + tensor_storage_map, + "first_stage_model", + version, + /*in_ch=*/3, + /*latent_ch=*/128, + /*patch=*/4, + /*decoder_base_ch=*/128, + /*timestep_cond=*/false, + LTXVAE::LTX2VAERunner::ltx2_22b_enc_specs(), + LTXVAE::LTX2VAERunner::ltx2_22b_dec_specs()); } else { auto model = std::make_shared(vae_backend, - offload_params_to_cpu, + vae_offload, tensor_storage_map, "first_stage_model", vae_decode_only, @@ -671,7 +1456,7 @@ class StableDiffusionGGML { LOG_INFO("using FakeVAE"); first_stage_model = std::make_shared(version, vae_backend, - offload_params_to_cpu); + vae_offload); } else if (use_tae && !tae_preview_only) { LOG_INFO("using TAE for encoding / decoding"); first_stage_model = create_tae(); @@ -718,7 +1503,7 @@ class StableDiffusionGGML { if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { pmid_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "pmid", version, @@ -726,7 +1511,7 @@ class StableDiffusionGGML { LOG_INFO("using PhotoMaker Version 2"); } else { pmid_model = std::make_shared(backend, - offload_params_to_cpu, + dit_offload, tensor_storage_map, "pmid", version); @@ -835,6 +1620,11 @@ class StableDiffusionGGML { ignore_tensors.insert("text_encoders.llm.vision_tower."); ignore_tensors.insert("text_encoders.llm.multi_modal_projector."); } + // Stash ignore_tensors so lazy-load callbacks (registered earlier) can + // re-pass it when ModelLoader::load_tensors fires per-component. + load_ignore_tensors = ignore_tensors; + // Debug: dump tensor names containing video_embeddings_connector to verify + // they're registered (only in lazy_load_cond mode). 
bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads, sd_ctx_params->enable_mmap); if (!success) { LOG_ERROR("load tensors from model loader failed"); @@ -960,6 +1750,8 @@ class StableDiffusionGGML { } } else if (sd_version_is_flux2(version)) { pred_type = FLUX2_FLOW_PRED; + } else if (sd_version_is_ltx2(version)) { + pred_type = LTX2_FLOW_PRED; } else { pred_type = EPS_PRED; } @@ -992,6 +1784,11 @@ class StableDiffusionGGML { denoiser = std::make_shared(); break; } + case LTX2_FLOW_PRED: { + LOG_INFO("running in LTX-2 FLOW mode"); + denoiser = std::make_shared(); + break; + } default: { LOG_ERROR("Unknown predition type %i", pred_type); ggml_free(ctx); @@ -1625,6 +2422,11 @@ class StableDiffusionGGML { float cfg_scale = guidance.txt_cfg; float img_cfg_scale = guidance.img_cfg; float slg_scale = guidance.slg.scale; + float rescale_scale = guidance.rescale_scale; + float stg_scale = guidance.stg_scale; + std::vector stg_blocks(guidance.stg_blocks, + guidance.stg_blocks + guidance.stg_blocks_count); + bool has_stg = stg_scale != 0.f && !stg_blocks.empty(); sd_sample::SampleCacheRuntime cache_runtime = sd_sample::init_sample_cache_runtime(version, cache_params, @@ -1686,6 +2488,7 @@ class StableDiffusionGGML { sd::Tensor uncond_out; sd::Tensor img_cond_out; sd::Tensor skip_cond_out; + sd::Tensor stg_cond_out; sd_sample::SampleStepCacheDispatcher step_cache(cache_runtime, step, sigma); std::vector> controls; DiffusionParams diffusion_params; @@ -1699,6 +2502,7 @@ class StableDiffusionGGML { diffusion_params.vace_context = vace_context.empty() ? 
nullptr : &vace_context; diffusion_params.vace_strength = vace_strength; diffusion_params.skip_layers = nullptr; + diffusion_params.stg_skip_blocks = nullptr; compute_sample_controls(control_image, noised_input, @@ -1707,14 +2511,16 @@ class StableDiffusionGGML { &controls); auto run_condition = [&](const SDCondition& condition, - const sd::Tensor* c_concat_override = nullptr, - const std::vector* local_skip_layers = nullptr) -> sd::Tensor { - diffusion_params.context = condition.c_crossattn.empty() ? nullptr : &condition.c_crossattn; - diffusion_params.c_concat = c_concat_override != nullptr ? c_concat_override : (condition.c_concat.empty() ? nullptr : &condition.c_concat); - diffusion_params.y = condition.c_vector.empty() ? nullptr : &condition.c_vector; - diffusion_params.t5_ids = condition.c_t5_ids.empty() ? nullptr : &condition.c_t5_ids; - diffusion_params.t5_weights = condition.c_t5_weights.empty() ? nullptr : &condition.c_t5_weights; - diffusion_params.skip_layers = local_skip_layers; + const sd::Tensor* c_concat_override = nullptr, + const std::vector* local_skip_layers = nullptr, + const std::vector* local_stg_skip_blocks = nullptr) -> sd::Tensor { + diffusion_params.context = condition.c_crossattn.empty() ? nullptr : &condition.c_crossattn; + diffusion_params.c_concat = c_concat_override != nullptr ? c_concat_override : (condition.c_concat.empty() ? nullptr : &condition.c_concat); + diffusion_params.y = condition.c_vector.empty() ? nullptr : &condition.c_vector; + diffusion_params.t5_ids = condition.c_t5_ids.empty() ? nullptr : &condition.c_t5_ids; + diffusion_params.t5_weights = condition.c_t5_weights.empty() ? 
nullptr : &condition.c_t5_weights; + diffusion_params.skip_layers = local_skip_layers; + diffusion_params.stg_skip_blocks = local_stg_skip_blocks; sd::Tensor cached_output; if (step_cache.before_condition(&condition, noised_input, &cached_output)) { @@ -1780,6 +2586,20 @@ class StableDiffusionGGML { } } + // STG (Spatio-Temporal Guidance): third forward pass with self-attention + // skipped on stg_blocks. The "weakened" prediction is mixed into the + // guided pred: pred += stg_scale * (cond - perturbed). Reference: + // ltx_core/components/guiders.py::calculate (perturbed term). + if (has_stg && !step_cache.is_step_skipped()) { + stg_cond_out = run_condition(cond, + cond.c_concat.empty() ? nullptr : &cond.c_concat, + /*local_skip_layers=*/nullptr, + &stg_blocks); + if (stg_cond_out.empty()) { + return {}; + } + } + GGML_ASSERT(!cond_out.empty()); sd::Tensor latent_result = cond_out; if (!uncond_out.empty()) { @@ -1797,7 +2617,45 @@ class StableDiffusionGGML { if (is_skiplayer_step && !skip_cond_out.empty()) { latent_result += (cond_out - skip_cond_out) * slg_scale; } + + // STG perturbed-pass mixing: pred += stg_scale * (cond - perturbed). + if (has_stg && !stg_cond_out.empty()) { + latent_result += (cond_out - stg_cond_out) * stg_scale; + } + denoised = latent_result * c_out + x * c_skip; + + // CFG-rescale: pull pred.std() back toward cond.std() to combat oversaturation + // introduced by CFG amplitude. Reference: ltx_core/components/guiders.py::calculate. + // Python operates on DENOISED predictions (X0Model returns denoised, then guider + // does CFG on denoised) — so we must compute std and multiply on the denoised + // tensor, not on the velocity-domain `latent_result`. Skip when rescale_scale==0 + // (default for non-LTX-2 models) or when only cond_out is present (no CFG mix + // happened — pred would equal cond_only and rescale would be a no-op). 
+ if (rescale_scale != 0.f && (!uncond_out.empty() || !img_cond_out.empty())) { + auto t_std = [](const float* d, int64_t n) -> double { + if (n <= 1) return 0.0; + double s = 0.0, sq = 0.0; + for (int64_t i = 0; i < n; ++i) { + double v = static_cast(d[i]); + s += v; + sq += v * v; + } + double mean = s / n; + double var = sq / n - mean * mean; + return std::sqrt(std::max(0.0, var)); + }; + // Denoised(cond_alone) = c_out * cond_out + c_skip * x. Materialize it just + // for the std computation; we don't want to apply CFG to it. + sd::Tensor denoised_cond_only = cond_out * c_out + x * c_skip; + double cond_std = t_std(denoised_cond_only.data(), denoised_cond_only.numel()); + double pred_std = t_std(denoised.data(), denoised.numel()); + if (pred_std > 1e-12) { + double factor = cond_std / pred_std; + factor = rescale_scale * factor + (1.0 - rescale_scale); + denoised *= static_cast(factor); + } + } if (cache_runtime.spectrum_enabled) { cache_runtime.spectrum.update(denoised); } @@ -1865,6 +2723,9 @@ class StableDiffusionGGML { latent_channel = 3; } else if (sd_version_uses_flux2_vae(version)) { latent_channel = 128; + } else if (sd_version_is_ltx2(version)) { + // LTX-2 VAE latent dim (matches DiT patchify_proj in_channels). + latent_channel = 128; } else { latent_channel = 16; } @@ -1872,9 +2733,24 @@ class StableDiffusionGGML { return latent_channel; } - int get_image_seq_len(int h, int w) { + int get_image_seq_len(int h, int w, int frames = 1) { int vae_scale_factor = get_vae_scale_factor(); - return (h / vae_scale_factor) * (w / vae_scale_factor); + int spatial_tokens = (h / vae_scale_factor) * (w / vae_scale_factor); + // For video flow-match schedulers (LTX-2, Wan), `tokens` in the shift + // formula is math.prod(latent.shape[2:]) = T_latent * H_latent * W_latent. 
+ // Earlier we only passed the spatial count (H*W), which under-shifted + // the LTX-2 schedule because the 22B run has 25-frame inputs → + // T_latent = 4, so the real token count is 4× the spatial count. + // Python reference: ltx_core/components/schedulers.py::LTX2Scheduler.execute. + if (frames > 1 && sd_version_is_ltx2(version)) { + int T_latent = ((frames - 1) / 8) + 1; // LTX-2 VAE: 8× temporal compression. + return spatial_tokens * T_latent; + } + if (frames > 1 && sd_version_is_wan(version)) { + int T_latent = ((frames - 1) / 4) + 1; // Wan VAE: 4× temporal compression. + return spatial_tokens * T_latent; + } + return spatial_tokens; } sd::Tensor generate_init_latent(int width, @@ -1887,6 +2763,9 @@ class StableDiffusionGGML { int T = frames; if (sd_version_is_wan(version)) { T = ((T - 1) / 4) + 1; + } else if (sd_version_is_ltx2(version)) { + // LTX-2 VAE: 8× temporal compression. + T = ((T - 1) / 8) + 1; } int C = get_latent_channel(); if (video) { @@ -2050,6 +2929,7 @@ const char* prediction_to_str[] = { "sd3_flow", "flux_flow", "flux2_flow", + "ltx2_flow", }; const char* sd_prediction_name(enum prediction_t prediction) { @@ -2162,6 +3042,16 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { sd_ctx_params->chroma_use_dit_mask = true; sd_ctx_params->chroma_use_t5_mask = false; sd_ctx_params->chroma_t5_mask_pad = 1; + + sd_ctx_params->auto_fit = true; + sd_ctx_params->auto_fit_target_mb = 512; + sd_ctx_params->auto_fit_dry_run = false; + sd_ctx_params->auto_fit_compute_reserve_dit_mb = 0; + sd_ctx_params->auto_fit_compute_reserve_vae_mb = 0; + sd_ctx_params->auto_fit_compute_reserve_cond_mb = 0; + sd_ctx_params->lazy_load_dit = true; + sd_ctx_params->lazy_load_cond = true; + sd_ctx_params->auto_tensor_split = true; } char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { @@ -2178,6 +3068,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "t5xxl_path: %s\n" "llm_path: %s\n" "llm_vision_path: %s\n" + 
"gemma_tokenizer_path: %s\n" "diffusion_model_path: %s\n" "high_noise_diffusion_model_path: %s\n" "vae_path: %s\n" @@ -2210,6 +3101,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { SAFE_STR(sd_ctx_params->t5xxl_path), SAFE_STR(sd_ctx_params->llm_path), SAFE_STR(sd_ctx_params->llm_vision_path), + SAFE_STR(sd_ctx_params->gemma_tokenizer_path), SAFE_STR(sd_ctx_params->diffusion_model_path), SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path), SAFE_STR(sd_ctx_params->vae_path), @@ -2244,6 +3136,10 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) { sample_params->guidance.txt_cfg = 7.0f; sample_params->guidance.img_cfg = INFINITY; sample_params->guidance.distilled_guidance = 3.5f; + sample_params->guidance.rescale_scale = 0.f; // LTX-2.3 expects 0.7 + sample_params->guidance.stg_scale = 0.f; // LTX-2.3 expects 1.0 with stg_blocks=[28] + sample_params->guidance.stg_blocks = nullptr; + sample_params->guidance.stg_blocks_count = 0; sample_params->guidance.slg.layer_count = 0; sample_params->guidance.slg.layer_start = 0.01f; sample_params->guidance.slg.layer_end = 0.2f; @@ -2382,6 +3278,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { sd_vid_gen_params->video_frames = 6; sd_vid_gen_params->moe_boundary = 0.875f; sd_vid_gen_params->vace_strength = 1.f; + sd_vid_gen_params->fps = 24.f; sd_vid_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; sd_cache_params_init(&sd_vid_gen_params->cache); } @@ -2510,6 +3407,23 @@ static float resolve_eta(sd_ctx_t* sd_ctx, return eta; } +// Mirrors the Python LTX-2 reference's DEFAULT_NEGATIVE_PROMPT +// (ltx-pipelines/utils/constants.py:135). Used when the caller does not pass +// a negative prompt for an LTX-2 video gen — empty negative + CFG ≥ 5 was +// observed to over-push attention into broken/dark scenes for some seeds. 
+static const char* LTX2_DEFAULT_NEGATIVE_PROMPT = + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."; + struct GenerationRequest { std::string prompt; std::string negative_prompt; @@ -2535,6 +3449,7 @@ struct GenerationRequest { sd_guidance_params_t high_noise_guidance = {}; sd_pm_params_t pm_params = {}; int frames = -1; + float fps = 0.f; // 0 = keep diffusion model's default float vace_strength = 1.f; GenerationRequest(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) { @@ -2562,9 +3477,30 @@ struct GenerationRequest { GenerationRequest(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) { prompt = SAFE_STR(sd_vid_gen_params->prompt); negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt); + const SDVersion version = sd_ctx->sd->version; + const bool is_ltx2 = sd_version_is_ltx2(version); + // LTX-2: default to the curated negative prompt from the Python + // 
reference (ltx-pipelines/utils/constants.py:135) when the caller + // didn't supply one. Empty negative + CFG ≥ 5 over-pushes attention + // and produces dark/distorted scenes for some seeds. + if (is_ltx2 && negative_prompt.empty()) { + negative_prompt = LTX2_DEFAULT_NEGATIVE_PROMPT; + LOG_INFO("LTX-2: using default negative prompt (caller passed empty). " + "Pass --negative-prompt to override."); + } width = sd_vid_gen_params->width; height = sd_vid_gen_params->height; - frames = (sd_vid_gen_params->video_frames - 1) / 4 * 4 + 1; + // LTX-2's VAE has 8× temporal compression, so the output frame count + // must satisfy (frames - 1) %% 8 == 0; other video models (Wan etc.) + // use 4× compression. Snap DOWN to the nearest valid value. + const int frame_stride = is_ltx2 ? 8 : 4; + const int requested_frames = sd_vid_gen_params->video_frames; + frames = (requested_frames - 1) / frame_stride * frame_stride + 1; + if (frames != requested_frames) { + LOG_WARN("%s: requested %d frames is not (N - 1) %% %d == 0; snapping to %d", + is_ltx2 ? 
"LTX-2" : "video", requested_frames, frame_stride, frames); + } + fps = sd_vid_gen_params->fps; clip_skip = sd_vid_gen_params->clip_skip; vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor(); @@ -2716,7 +3652,7 @@ struct SamplePlan { sample_params->scheduler, sample_method); sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, - sd_ctx->sd->get_image_seq_len(request->height, request->width), + sd_ctx->sd->get_image_seq_len(request->height, request->width, request->frames), scheduler, sd_ctx->sd->version); } @@ -3157,6 +4093,12 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s int64_t t0 = ggml_time_ms(); sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; GenerationRequest request(sd_ctx, sd_img_gen_params); + { + const int vsf = std::max(1, request.vae_scale_factor); + const int64_t lw = request.width / vsf; + const int64_t lh = request.height / vsf; + sd_ctx->sd->maybe_auto_set_vae_tiling(lw, lh, /*t_latent=*/1); + } LOG_INFO("generate_image %dx%d", request.width, request.height); sd_ctx->sd->rng->manual_seed(request.seed); @@ -3489,6 +4431,40 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx, sd::Tensor vid = sd_ctx->sd->decode_first_stage(final_latent, true); int64_t t5 = ggml_time_ms(); LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000); + // Diagnostic (gated by SD_DUMP_GEN_STATS=1): post-decode video stats. + // Should be in [0, 1] after the scale_output rescale; mean ~0.3-0.5 for a + // natural image. If mean is near 0, decoder output is being clamped. 
+ if (std::getenv("SD_DUMP_GEN_STATS") && !vid.empty()) { + const auto& sh = vid.shape(); + const float* d = vid.data(); + int64_t n = vid.numel(); + double s = 0, sq = 0; float lo = d[0], hi = d[0]; + for (int64_t i = 0; i < n; ++i) { + double v = d[i]; s += v; sq += v*v; + if (d[i] < lo) lo = d[i]; if (d[i] > hi) hi = d[i]; + } + double m = s / n; double v = sq / n - m * m; + std::printf("=== decoded video stats: shape=["); + for (size_t i = 0; i < sh.size(); ++i) std::printf("%s%lld", i ? "," : "", (long long)sh[i]); + std::printf("] overall mean=%.3f std=%.3f min=%.3f max=%.3f ===\n", + m, std::sqrt(std::max(0.0, v)), lo, hi); + // Per-channel breakdown if 5D [W,H,T,C,B] or 4D [W,H,T,C] + if (sh.size() >= 4) { + int64_t W = sh[0], H = sh[1], T = sh[2], C = sh[3]; + for (int64_t c = 0; c < std::min(C, 3); ++c) { + double cs = 0, csq = 0; int64_t cn = W * H * T; + for (int64_t t = 0; t < T; ++t) + for (int64_t h = 0; h < H; ++h) + for (int64_t w = 0; w < W; ++w) { + double vv = d[((c * T + t) * H + h) * W + w]; + cs += vv; csq += vv*vv; + } + double cm = cs / cn; double cv = csq / cn - cm * cm; + std::printf(" channel %lld: mean=%.3f std=%.3f\n", (long long)c, + cm, std::sqrt(std::max(0.0, cv))); + } + } + } if (sd_ctx->sd->free_params_immediately) { sd_ctx->sd->first_stage_model->free_params_buffer(); } @@ -3522,11 +4498,32 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s int64_t t0 = ggml_time_ms(); sd_ctx->sd->vae_tiling_params = sd_vid_gen_params->vae_tiling_params; GenerationRequest request(sd_ctx, sd_vid_gen_params); + { + const int vsf = std::max(1, request.vae_scale_factor); + const int64_t lw = request.width / vsf; + const int64_t lh = request.height / vsf; + // Temporal compression: WAN /4, LTX-2 /8, otherwise no compression. + // GenerationRequest::frames already reflects the user request; the + // latent T is computed the same way StableDiffusionGGML does it. 
+ int64_t t_latent = request.frames; + if (sd_version_is_wan(sd_ctx->sd->version)) t_latent = ((t_latent - 1) / 4) + 1; + else if (sd_version_is_ltx2(sd_ctx->sd->version)) t_latent = ((t_latent - 1) / 8) + 1; + sd_ctx->sd->maybe_auto_set_vae_tiling(lw, lh, t_latent); + } sd_ctx->sd->rng->manual_seed(request.seed); sd_ctx->sd->sampler_rng->manual_seed(request.seed); sd_ctx->sd->set_flow_shift(sd_vid_gen_params->sample_params.flow_shift); sd_ctx->sd->apply_loras(sd_vid_gen_params->loras, sd_vid_gen_params->lora_count); + // Propagate output fps to diffusion models that need it for temporal RoPE + // (LTX-2 divides time positions by fps; see LTXRope::gen_video_positions). + if (request.fps > 0.f && sd_ctx->sd->diffusion_model) { + sd_ctx->sd->diffusion_model->set_fps(request.fps); + } + if (request.fps > 0.f && sd_ctx->sd->high_noise_diffusion_model) { + sd_ctx->sd->high_noise_diffusion_model->set_fps(request.fps); + } + SamplePlan plan(sd_ctx, sd_vid_gen_params, request); auto latent_inputs_opt = prepare_video_generation_latents(sd_ctx, sd_vid_gen_params, &request); if (!latent_inputs_opt.has_value()) { @@ -3640,6 +4637,45 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s int64_t latent_end = ggml_time_ms(); LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000); + // Diagnostic (gated by SD_DUMP_GEN_STATS=1): per-channel mean/std of the + // final latent. Useful when VAE output looks off — confirms the latent is + // in-distribution (~zero mean, unit std per channel post-PCS-normalize). 
+ if (std::getenv("SD_DUMP_GEN_STATS")) { + const auto& sh = final_latent.shape(); + if (sh.size() >= 4) { + int64_t W = sh[0], H = sh[1], T = sh[2], C = sh[3]; + const float* d = final_latent.data(); + double overall_s = 0, overall_sq = 0; + int64_t overall_n = 0; + std::printf("=== final_latent stats (W=%lld H=%lld T=%lld C=%lld) ===\n", + (long long)W, (long long)H, (long long)T, (long long)C); + std::printf("first 4 channels (mean/std):"); + for (int64_t c = 0; c < std::min(C, 4); ++c) { + double s = 0, sq = 0; int64_t n = W * H * T; + for (int64_t t = 0; t < T; ++t) + for (int64_t h = 0; h < H; ++h) + for (int64_t w = 0; w < W; ++w) { + double v = d[((c * T + t) * H + h) * W + w]; + s += v; sq += v*v; + } + double mean = s / n; + double var = sq / n - mean * mean; + std::printf(" c%lld=%.3f/%.3f", (long long)c, mean, std::sqrt(std::max(0.0, var))); + } + for (int64_t i = 0; i < final_latent.numel(); ++i) { + double v = d[i]; overall_s += v; overall_sq += v*v; + } + overall_n = final_latent.numel(); + double om = overall_s / overall_n; + double ov = overall_sq / overall_n - om * om; + std::printf("\noverall: mean=%.3f std=%.3f min/max=", om, std::sqrt(std::max(0.0, ov))); + const float* d2 = final_latent.data(); + float lo = d2[0], hi = d2[0]; + for (int64_t i = 1; i < overall_n; ++i) { if (d2[i] < lo) lo = d2[i]; if (d2[i] > hi) hi = d2[i]; } + std::printf("%.3f/%.3f\n", lo, hi); + } + } + auto result = decode_video_outputs(sd_ctx, final_latent, num_frames_out); if (result == nullptr) { return nullptr; diff --git a/src/tokenizers/gemma_tokenizer.cpp b/src/tokenizers/gemma_tokenizer.cpp new file mode 100644 index 000000000..43c772abe --- /dev/null +++ b/src/tokenizers/gemma_tokenizer.cpp @@ -0,0 +1,254 @@ +#include "gemma_tokenizer.h" + +#include +#include +#include +#include +#include + +#include "json.hpp" +#include "util.h" + +namespace { + +// Parse "<0xAB>" -> byte value 0xAB. Returns -1 if not a byte token. 
+int parse_byte_token(const std::string& piece) {
+    if (piece.size() != 6) return -1;
+    if (piece[0] != '<' || piece[1] != '0' || piece[2] != 'x' || piece[5] != '>') return -1;
+    auto hex = [](char c) -> int {
+        if (c >= '0' && c <= '9') return c - '0';
+        if (c >= 'A' && c <= 'F') return 10 + c - 'A';
+        if (c >= 'a' && c <= 'f') return 10 + c - 'a';
+        return -1;
+    };
+    int hi = hex(piece[3]), lo = hex(piece[4]);
+    if (hi < 0 || lo < 0) return -1;
+    return (hi << 4) | lo;
+}
+
+}  // namespace
+
+GemmaTokenizer::GemmaTokenizer() {
+    byte_fallback_ids_.fill(-1);
+    BOS_TOKEN     = "<bos>";
+    EOS_TOKEN     = "<eos>";
+    PAD_TOKEN     = "<pad>";
+    UNK_TOKEN     = "<unk>";
+    BOS_TOKEN_ID  = 2;
+    EOS_TOKEN_ID  = 1;
+    PAD_TOKEN_ID  = 0;
+    UNK_TOKEN_ID  = 3;
+    add_bos_token = true;  // Gemma post-processor prepends <bos>.
+    add_eos_token = false;
+    pad_left      = true;  // Gemma uses left padding.
+}
+
+std::string GemmaTokenizer::decode_token(int token_id) const {
+    if (token_id >= 0 && token_id < (int)id_to_piece_.size()) {
+        return id_to_piece_[token_id];
+    }
+    return "";
+}
+
+// HF normalizer: Replace " " -> "▁" (U+2581). All other chars untouched.
+std::string GemmaTokenizer::normalize(const std::string& text) const {
+    static const std::string metaspace = "\xe2\x96\x81";  // UTF-8 for U+2581
+    std::string out;
+    out.reserve(text.size() + text.size() / 8);
+    for (char c : text) {
+        if (c == ' ') {
+            out.append(metaspace);
+        } else {
+            out.push_back(c);
+        }
+    }
+    return out;
+}
+
+std::vector<std::string> GemmaTokenizer::split_utf8_chars(const std::string& s) {
+    std::vector<std::string> out;
+    size_t i = 0;
+    while (i < s.size()) {
+        unsigned char b = (unsigned char)s[i];
+        size_t len;
+        if (b < 0x80) len = 1;
+        else if (b < 0xC0) len = 1;  // malformed continuation; treat as 1-byte
+        else if (b < 0xE0) len = 2;
+        else if (b < 0xF0) len = 3;
+        else len = 4;
+        if (i + len > s.size()) len = s.size() - i;
+        out.emplace_back(s.substr(i, len));
+        i += len;
+    }
+    return out;
+}
+
+void GemmaTokenizer::byte_fallback(const std::string& ch, std::vector<std::string>& out) const {
+    for (unsigned char b : ch) {
+        int id = byte_fallback_ids_[b];
+        if (id >= 0) {
+            out.push_back(id_to_piece_[id]);  // "<0xNN>"
+        } else {
+            out.push_back(UNK_TOKEN);
+        }
+    }
+}
+
+std::vector<std::string> GemmaTokenizer::bpe(std::vector<std::string> pieces) const {
+    // Greedy BPE: at each step find the adjacent pair with lowest merge rank and apply it.
+    // O(N^2 * merges_lookup) per encode. N here is chars in a single chunk — a few hundred
+    // at most for our use. Good enough.
+    while (pieces.size() > 1) {
+        int best_rank = INT_MAX;
+        int best_i    = -1;
+        for (size_t i = 0; i + 1 < pieces.size(); ++i) {
+            std::string key = pieces[i];
+            key.push_back('\t');
+            key.append(pieces[i + 1]);
+            auto it = merge_ranks_.find(key);
+            if (it != merge_ranks_.end() && it->second < best_rank) {
+                best_rank = it->second;
+                best_i    = (int)i;
+            }
+        }
+        if (best_i < 0) break;
+        pieces[best_i] = pieces[best_i] + pieces[best_i + 1];
+        pieces.erase(pieces.begin() + best_i + 1);
+    }
+    return pieces;
+}
+
+std::vector<int> GemmaTokenizer::encode(const std::string& text, on_new_token_cb_t /*cb*/) {
+    if (!loaded_) {
+        LOG_ERROR("GemmaTokenizer::encode called before load_from_file()");
+        return {};
+    }
+
+    std::string normalized = normalize(text);
+
+    // ignore_merges=true: if the entire (post-normalization) chunk is directly in vocab,
+    // emit it as a single token without running BPE.
+    if (ignore_merges_) {
+        auto it = vocab_.find(normalized);
+        if (it != vocab_.end()) {
+            return {it->second};
+        }
+    }
+
+    std::vector<std::string> pieces;
+    pieces.reserve(normalized.size());
+    for (const auto& ch : split_utf8_chars(normalized)) {
+        auto it = vocab_.find(ch);
+        if (it != vocab_.end()) {
+            pieces.push_back(ch);
+        } else {
+            byte_fallback(ch, pieces);
+        }
+    }
+
+    pieces = bpe(std::move(pieces));
+
+    std::vector<int> ids;
+    ids.reserve(pieces.size());
+    for (const auto& p : pieces) {
+        auto it = vocab_.find(p);
+        if (it != vocab_.end()) {
+            ids.push_back(it->second);
+        } else {
+            ids.push_back(UNK_TOKEN_ID);
+        }
+    }
+    return ids;
+}
+
+bool GemmaTokenizer::load_from_file(const std::string& path) {
+    std::ifstream f(path);
+    if (!f.is_open()) {
+        LOG_ERROR("GemmaTokenizer: cannot open %s", path.c_str());
+        return false;
+    }
+    nlohmann::json j;
+    try {
+        f >> j;
+    } catch (const nlohmann::json::parse_error& e) {
+        LOG_ERROR("GemmaTokenizer: JSON parse error in %s: %s", path.c_str(), e.what());
+        return false;
+    }
+
+    if (!j.contains("model") || !j["model"].contains("vocab") ||
+        !j["model"].contains("merges")) {
+        LOG_ERROR("GemmaTokenizer: JSON missing model.vocab or model.merges");
+        return false;
+    }
+    const auto& model = j["model"];
+
+    // Vocab: HF tokenizer.json for BPE stores vocab as an object {piece: id}.
+    const auto& vocab = model["vocab"];
+    id_to_piece_.clear();
+    id_to_piece_.resize(vocab.size());
+    vocab_.reserve(vocab.size() * 2);
+    for (auto it = vocab.begin(); it != vocab.end(); ++it) {
+        const std::string piece = it.key();
+        int id                  = it.value().get<int>();
+        vocab_.emplace(piece, id);
+        if (id >= 0 && id < (int)id_to_piece_.size()) {
+            id_to_piece_[id] = piece;
+        }
+    }
+
+    // Merges: ordered list; earlier entries have higher priority (lower rank).
+    const auto& merges = model["merges"];
+    merge_ranks_.reserve(merges.size() * 2);
+    int rank = 0;
+    for (const auto& m : merges) {
+        // tokenizers >=0.20 stores each merge as a [left, right] array; older versions used a
+        // single space-separated string. Accept both for robustness.
+        std::string left, right;
+        if (m.is_array() && m.size() == 2) {
+            left  = m[0].get<std::string>();
+            right = m[1].get<std::string>();
+        } else if (m.is_string()) {
+            const std::string s = m.get<std::string>();
+            auto pos            = s.find(' ');
+            if (pos == std::string::npos) continue;
+            left  = s.substr(0, pos);
+            right = s.substr(pos + 1);
+        } else {
+            continue;
+        }
+        std::string key = left;
+        key.push_back('\t');
+        key.append(right);
+        merge_ranks_.emplace(std::move(key), rank++);
+    }
+
+    // Locate byte-fallback IDs. Every byte value 0..255 should have a "<0xNN>" entry.
+    for (int id = 0; id < (int)id_to_piece_.size(); ++id) {
+        int b = parse_byte_token(id_to_piece_[id]);
+        if (b >= 0) {
+            byte_fallback_ids_[b] = id;
+        }
+    }
+
+    // Special token IDs: honor what's actually in the JSON if model is unusual; otherwise
+    // keep the Gemma-3 defaults from the ctor.
+    if (model.contains("unk_token") && model["unk_token"].is_string()) {
+        auto it = vocab_.find(model["unk_token"].get<std::string>());
+        if (it != vocab_.end()) UNK_TOKEN_ID = it->second;
+    }
+    if (j.contains("added_tokens")) {
+        for (const auto& at : j["added_tokens"]) {
+            if (!at.contains("content") || !at.contains("id")) continue;
+            const std::string c = at["content"].get<std::string>();
+            int id              = at["id"].get<int>();
+            if (c == "<bos>") BOS_TOKEN_ID = id;
+            else if (c == "<eos>") EOS_TOKEN_ID = id;
+            else if (c == "<pad>") PAD_TOKEN_ID = id;
+            else if (c == "<unk>") UNK_TOKEN_ID = id;
+        }
+    }
+
+    ignore_merges_ = model.value("ignore_merges", true);
+    loaded_        = true;
+    LOG_DEBUG("GemmaTokenizer loaded: vocab=%zu merges=%zu", vocab_.size(), merge_ranks_.size());
+    return true;
+}
diff --git a/src/tokenizers/gemma_tokenizer.h b/src/tokenizers/gemma_tokenizer.h
new file mode 100644
index 000000000..8753cbd9f
--- /dev/null
+++ b/src/tokenizers/gemma_tokenizer.h
@@ -0,0 +1,50 @@
+#ifndef __SD_TOKENIZERS_GEMMA_TOKENIZER_H__
+#define __SD_TOKENIZERS_GEMMA_TOKENIZER_H__
+
+#include <array>
+#include <cstdint>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tokenizer.h"
+
+// Gemma 3 tokenizer. BPE with byte-fallback + Metaspace-style normalization
+// (space → U+2581 "▁"). Loads a HuggingFace tokenizer.json produced by
+// `AutoTokenizer.from_pretrained("google/gemma-3-12b-it").backend_tokenizer.save()`.
+//
+// Not embeddable as a header like the other tokenizers — the raw JSON is ~33 MB
+// and the vocab alone is 262144 pieces plus 514906 merges. Expected workflow:
+// ship the tokenizer.json file alongside the weights, pass its path at runtime.
+class GemmaTokenizer : public Tokenizer {
+protected:
+    std::unordered_map<std::string, int> vocab_;        // piece -> id
+    std::vector<std::string> id_to_piece_;              // id -> piece
+    std::unordered_map<std::string, int> merge_ranks_;  // "left\tright" -> rank (lower = higher priority)
+    std::array<int, 256> byte_fallback_ids_{};          // byte value -> piece id for <0xXX>
+    bool loaded_        = false;
+    bool ignore_merges_ = true;
+
+    std::string decode_token(int token_id) const override;
+    std::string normalize(const std::string& text) const override;
+
+    // Split a UTF-8 string into its individual code-point-sized chunks.
+    static std::vector<std::string> split_utf8_chars(const std::string& s);
+
+    // Byte-fallback a character that isn't in vocab: produce UTF-8 byte tokens.
+    void byte_fallback(const std::string& ch, std::vector<std::string>& out) const;
+
+    // Run BPE merging until no more merges apply.
+    std::vector<std::string> bpe(std::vector<std::string> pieces) const;
+
+public:
+    GemmaTokenizer();
+
+    bool load_from_file(const std::string& path);
+    bool is_loaded() const { return loaded_; }
+    int vocab_size() const { return (int)id_to_piece_.size(); }
+
+    std::vector<int> encode(const std::string& text, on_new_token_cb_t on_new_token_cb = nullptr) override;
+};
+
+#endif  // __SD_TOKENIZERS_GEMMA_TOKENIZER_H__
diff --git a/src/vae.hpp b/src/vae.hpp
index dc69535e8..6efc59cf6 100644
--- a/src/vae.hpp
+++ b/src/vae.hpp
@@ -7,7 +7,12 @@
 struct VAE : public GGMLRunner {
 protected:
     SDVersion version;
+    // scale_input: encode-time [0,1]→[-1,1] of the user-provided RGB image.
+    // scale_output: decode-time [-1,1]→[0,1] of the decoder's RGB output.
+    // These are independent: LTX-2 takes [-1,1] inputs (no encode scaling) but
+    // still produces [-1,1] outputs that callers expect mapped to [0,1].
bool scale_input = true; + bool scale_output = true; virtual sd::Tensor _compute(const int n_threads, const sd::Tensor& z, bool decode_graph) = 0; @@ -73,6 +78,9 @@ struct VAE : public GGMLRunner { scale_factor = 16; } else if (version == VERSION_CHROMA_RADIANCE) { scale_factor = 1; + } else if (sd_version_is_ltx2(version)) { + // LTX-2 VAE: 32× spatial compression (256×256 → 8×8 latent). + scale_factor = 32; } return scale_factor; } @@ -199,7 +207,7 @@ struct VAE : public GGMLRunner { LOG_ERROR("vae decode compute failed"); return {}; } - if (scale_input) { + if (scale_output) { scale_tensor_to_0_1(&output); } int64_t t1 = ggml_time_ms(); diff --git a/tests/ltx_parity/CMakeLists.txt b/tests/ltx_parity/CMakeLists.txt new file mode 100644 index 000000000..2393a149c --- /dev/null +++ b/tests/ltx_parity/CMakeLists.txt @@ -0,0 +1,134 @@ +set(TARGET sd-ltx-parity) + +add_executable(${TARGET} + test_ltx_parity.cpp +) + +target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17) + +set(GEMMA_TARGET sd-gemma-parity) + +add_executable(${GEMMA_TARGET} + test_gemma_parity.cpp +) + +target_link_libraries(${GEMMA_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${GEMMA_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(GEMMA_TOK_TARGET sd-gemma-tokenizer-test) + +add_executable(${GEMMA_TOK_TARGET} + test_gemma_tokenizer.cpp +) + +target_link_libraries(${GEMMA_TOK_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${GEMMA_TOK_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(S2D_TARGET sd-s2d-primitives-test) + +add_executable(${S2D_TARGET} + test_s2d_primitives.cpp +) + +target_link_libraries(${S2D_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${S2D_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(VAE_TARGET sd-vae-parity) + +add_executable(${VAE_TARGET} + test_vae_parity.cpp +) + 
+target_link_libraries(${VAE_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${VAE_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(CONN_TARGET sd-connector-parity) + +add_executable(${CONN_TARGET} + test_connector_parity.cpp +) + +target_link_libraries(${CONN_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${CONN_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(VAE_RT_TARGET sd-ltx2-vae-roundtrip) + +add_executable(${VAE_RT_TARGET} + test_ltx2_vae_roundtrip.cpp +) + +target_link_libraries(${VAE_RT_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${VAE_RT_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(GEMMA_CC_TARGET sd-gemma-cpu-vs-cuda) + +add_executable(${GEMMA_CC_TARGET} + test_gemma_cpu_vs_cuda.cpp +) + +target_link_libraries(${GEMMA_CC_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${GEMMA_CC_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(MM_PARITY_TARGET sd-mm-f32-parity) + +add_executable(${MM_PARITY_TARGET} + test_mm_f32_parity.cpp +) + +target_link_libraries(${MM_PARITY_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${MM_PARITY_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(CONT_PARITY_TARGET sd-cont-parity) + +add_executable(${CONT_PARITY_TARGET} + test_cont_parity.cpp +) + +target_link_libraries(${CONT_PARITY_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${CONT_PARITY_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(SOFTMAX_PARITY_TARGET sd-softmax-parity) + +add_executable(${SOFTMAX_PARITY_TARGET} + test_softmax_parity.cpp +) + +target_link_libraries(${SOFTMAX_PARITY_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${SOFTMAX_PARITY_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(ATTN_CHAIN_TARGET sd-attn-chain-parity) + +add_executable(${ATTN_CHAIN_TARGET} + test_attn_chain_parity.cpp +) + 
+target_link_libraries(${ATTN_CHAIN_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${ATTN_CHAIN_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(AV_BLOCK_SMOKE_TARGET sd-av-block-smoke) + +add_executable(${AV_BLOCK_SMOKE_TARGET} + test_av_block_smoke.cpp +) + +target_link_libraries(${AV_BLOCK_SMOKE_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${AV_BLOCK_SMOKE_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(AV_BLOCK_PARITY_TARGET sd-av-block-parity) + +add_executable(${AV_BLOCK_PARITY_TARGET} + test_av_block_parity.cpp +) + +target_link_libraries(${AV_BLOCK_PARITY_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${AV_BLOCK_PARITY_TARGET} PUBLIC c_std_11 cxx_std_17) + +set(AV_MODEL_PARITY_TARGET sd-av-model-parity) + +add_executable(${AV_MODEL_PARITY_TARGET} + test_av_model_parity.cpp +) + +target_link_libraries(${AV_MODEL_PARITY_TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${AV_MODEL_PARITY_TARGET} PUBLIC c_std_11 cxx_std_17) diff --git a/tests/ltx_parity/README.md b/tests/ltx_parity/README.md new file mode 100644 index 000000000..8f5f090fa --- /dev/null +++ b/tests/ltx_parity/README.md @@ -0,0 +1,36 @@ +# LTX-2 parity tests + +Block-by-block numerical parity between the C++/GGML LTX-2 port and the reference +PyTorch implementation in `/devel/tools/diffusion/LTX-2/packages/ltx-core/`. + +## How it works + +`dump_reference.py` instantiates a **tiny, deterministic** LTX-2 transformer with fixed +random weights (seed=0) and runs a forward pass on a fixed input. 
It writes: + +- `/tmp/ltx_ref/manifest.json` — catalogue of every dumped tensor (name, shape, dtype, offset) +- `/tmp/ltx_ref/state_dict.safetensors` — all model weights in a standard format +- `/tmp/ltx_ref/tensors/*.bin` — each intermediate tensor as raw float32 bytes + +The "tiny" model is small enough (2 layers, inner_dim=128) to run in milliseconds on CPU +and make it easy to dump every intermediate without filling the disk. That scope is +deliberate: parity at tiny dims transfers to full-size models because every block is +tested exhaustively. + +A matching C++ test (to be written) loads `state_dict.safetensors`, replays the same +input, and diffs every intermediate tensor against the reference. Tolerances: +- F32: 1e-5 absolute, 1e-4 relative +- BF16/FP16 C++ path: 1e-2 absolute, 5e-3 relative + +## Run + +```bash +/home/ilintar/venv/bin/python dump_reference.py +``` + +## What's NOT covered (yet) + +- **Gemma 3 text encoder** — needs a Gemma 3 checkpoint. Deferred; we dump a synthetic + `context` tensor (random but fixed) as a placeholder. +- **VAE** — separate dumper planned once the C++ VAE is building. +- **Sampler loop** — a separate script, not this one. This one tests a single forward call. diff --git a/tests/ltx_parity/dump_av_block.py b/tests/ltx_parity/dump_av_block.py new file mode 100644 index 000000000..13691e185 --- /dev/null +++ b/tests/ltx_parity/dump_av_block.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +"""Dump LTX-2 AV transformer block weights, inputs, and outputs for C++ parity testing. + +Strategy: instantiate a TINY AV-enabled BasicAVTransformerBlock with the SAME flags +as the production 22B model (cross_attention_adaln=True, apply_gated_attention=True, +audio dim=64, video dim=128). 
Run preprocessor + block forward on deterministic random +inputs and dump: + + - block state_dict as named .bin files (raw fp32 bytes per parameter) + - all TransformerArgs fields needed by our C++ forward_av (video + audio sides) + - block outputs (vx_out, ax_out) + +Layout is chosen so raw bytes are interpretable as ggml ne (column-major fastest). +Python's torch tensor [B, T, dim] memory layout matches ggml ne [dim, T, B], so we +write `.numpy().tobytes()` of contiguous tensors directly. + +Outputs: + /tmp/ltx_av_block_ref/ + manifest.json -- catalogue {name -> {shape, path}} + weights/*.bin -- block parameters + inputs/*.bin -- TransformerArgs fields + outputs/*.bin -- vx_out, ax_out + +Usage: + /home/ilintar/venv/bin/python tests/ltx_parity/dump_av_block.py +""" + +from __future__ import annotations + +import json +import pathlib +from dataclasses import dataclass, field +from typing import Dict, List, Tuple + +import numpy as np +import torch + +from ltx_core.model.transformer.adaln import AdaLayerNormSingle +from ltx_core.model.transformer.attention import AttentionFunction +from ltx_core.model.transformer.modality import Modality +from ltx_core.model.transformer.rope import LTXRopeType +from ltx_core.model.transformer.transformer import ( + BasicAVTransformerBlock, + TransformerConfig, +) +from ltx_core.model.transformer.transformer_args import ( + MultiModalTransformerArgsPreprocessor, +) + + +# ============================================================================ +# Config — tiny dims, FULL feature flags matching 22B +# ============================================================================ + +SEED = 0xA1C0DE +OUT_DIR = pathlib.Path("/tmp/ltx_av_block_ref") +W_DIR = OUT_DIR / "weights" +IN_DIR = OUT_DIR / "inputs" +OUT_DIR_T = OUT_DIR / "outputs" + +# ---- Modality dims ---- +# Video: heads × d_head = inner_dim. Tiny but matches the 22B's "heads, d_head" +# pattern (32×128 there → inner_dim 4096; we use 4×32 → 128). 
+VIDEO_HEADS = 4 +VIDEO_D_HEAD = 32 +VIDEO_DIM = VIDEO_HEADS * VIDEO_D_HEAD # 128 +# In production 22B: cross_attention_dim == video.dim. The cross_attention_adaln +# path applies prompt_scale_shift_table[shape (2, video.dim)] elementwise to the +# context, which requires context_dim == video.dim. Mirror that here. +VIDEO_CTX_DIM = VIDEO_DIM + +# Audio: 22B uses 32×64 (smaller per-head) → inner_dim 2048. We mirror with 4×16 → 64. +AUDIO_HEADS = 4 +AUDIO_D_HEAD = 16 +AUDIO_DIM = AUDIO_HEADS * AUDIO_D_HEAD # 64 +AUDIO_CTX_DIM = AUDIO_DIM # same constraint as video + +# Sequence lengths (B=1). +B = 1 +F_LAT = 2 # video frames +H_LAT = 3 +W_LAT = 4 +T_VIDEO = F_LAT * H_LAT * W_LAT # 24 video tokens +T_AUDIO = 5 # audio tokens +S_VIDEO = 6 # video text-context tokens +S_AUDIO = 4 # audio text-context tokens + +# ---- Feature flags (match 22B) ---- +CROSS_ATTENTION_ADALN = True +APPLY_GATED_ATTENTION = True +ROPE_TYPE = LTXRopeType.INTERLEAVED # block-level test; INTERLEAVED is well-tested in our C++ path +NORM_EPS = 1e-6 + +# ---- Positional embedding params ---- +VIDEO_POS_DIMS = 3 # (frame, h, w) +AUDIO_POS_DIMS = 1 # frame index only +VIDEO_MAX_POS = [20, 2048, 2048] +AUDIO_MAX_POS = [20] +USE_MIDDLE_INDICES_GRID = True +POS_EMB_THETA = 10000.0 +DOUBLE_PRECISION_ROPE = False +TIMESTEP_SCALE_MULT = 1000 +AV_CA_TS_SCALE_MULT = 1 + + +# ============================================================================ +# Manifest helper +# ============================================================================ + + +@dataclass +class Manifest: + weights: Dict[str, Dict] = field(default_factory=dict) + inputs: Dict[str, Dict] = field(default_factory=dict) + outputs: Dict[str, Dict] = field(default_factory=dict) + config: Dict = field(default_factory=dict) + + def _add(self, target: Dict, name: str, t: torch.Tensor, dest: pathlib.Path): + arr = t.detach().to(torch.float32).contiguous().cpu().numpy() + path = dest / (name.replace("/", "__") + ".bin") + 
path.write_bytes(arr.tobytes()) + target[name] = {"shape": list(arr.shape), "dtype": "float32", + "nbytes": arr.nbytes, "path": str(path.relative_to(OUT_DIR))} + + def add_weight(self, name: str, t: torch.Tensor): + self._add(self.weights, name, t, W_DIR) + + def add_input(self, name: str, t: torch.Tensor): + self._add(self.inputs, name, t, IN_DIR) + + def add_output(self, name: str, t: torch.Tensor): + self._add(self.outputs, name, t, OUT_DIR_T) + + def write(self, path: pathlib.Path): + path.write_text(json.dumps({ + "config": self.config, + "weights": self.weights, + "inputs": self.inputs, + "outputs": self.outputs, + }, indent=2, default=str)) + + +# ============================================================================ +# Build positions for a modality +# ============================================================================ + + +def make_video_positions(F: int, H: int, W: int, device, fps: float = 24.0) -> torch.Tensor: + """Generate INDICES_GRID for video positions. Shape: (B, n_pos_dims=3, T_total). + With use_middle_indices_grid=True we need a (B, n_pos_dims, T, 2) tensor where + [..., 0] = start index and [..., 1] = end index of the indices grid cell. + For a non-VAE-compressed test, start==end == [frame_idx, h, w]. + """ + # Build coords [F*H*W, 3] in (f, h, w) order, then transpose to [3, T]. + coords = [] + for f in range(F): + for h in range(H): + for w in range(W): + coords.append([f, h, w]) + coords = torch.tensor(coords, dtype=torch.float32, device=device).t() # [3, T] + # Replicate start==end along last axis for use_middle_indices_grid=True. 
+ grid = coords.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, 2).contiguous() # [B, 3, T, 2] + return grid + + +def make_audio_positions(T: int, device) -> torch.Tensor: + """Audio positions are 1-D (frame index).""" + coords = torch.arange(T, dtype=torch.float32, device=device).unsqueeze(0) # [1, T] + grid = coords.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, 2).contiguous() # [B, 1, T, 2] + return grid + + +# ============================================================================ +# Main +# ============================================================================ + + +def main() -> None: + torch.manual_seed(SEED) + OUT_DIR.mkdir(parents=True, exist_ok=True) + W_DIR.mkdir(exist_ok=True) + IN_DIR.mkdir(exist_ok=True) + OUT_DIR_T.mkdir(exist_ok=True) + + device = torch.device("cpu") + dtype = torch.float32 + + manifest = Manifest() + manifest.config = { + "video": {"dim": VIDEO_DIM, "heads": VIDEO_HEADS, "d_head": VIDEO_D_HEAD, + "ctx_dim": VIDEO_CTX_DIM, "T": T_VIDEO, "S": S_VIDEO, + "F": F_LAT, "H": H_LAT, "W": W_LAT}, + "audio": {"dim": AUDIO_DIM, "heads": AUDIO_HEADS, "d_head": AUDIO_D_HEAD, + "ctx_dim": AUDIO_CTX_DIM, "T": T_AUDIO, "S": S_AUDIO}, + "B": B, + "cross_attention_adaln": CROSS_ATTENTION_ADALN, + "apply_gated_attention": APPLY_GATED_ATTENTION, + "rope_type": ROPE_TYPE.value, + "norm_eps": NORM_EPS, + "video_max_pos": VIDEO_MAX_POS, + "audio_max_pos": AUDIO_MAX_POS, + "use_middle_indices_grid": USE_MIDDLE_INDICES_GRID, + "pos_emb_theta": POS_EMB_THETA, + "timestep_scale_multiplier": TIMESTEP_SCALE_MULT, + "av_ca_timestep_scale_multiplier": AV_CA_TS_SCALE_MULT, + "audio_cross_attention_dim": AUDIO_CTX_DIM, + "cross_pe_max_pos": max(VIDEO_MAX_POS[0], AUDIO_MAX_POS[0]), + "seed": SEED, + } + + # ---- Build the AV block ---- + video_cfg = TransformerConfig( + dim=VIDEO_DIM, heads=VIDEO_HEADS, d_head=VIDEO_D_HEAD, + context_dim=VIDEO_CTX_DIM, + apply_gated_attention=APPLY_GATED_ATTENTION, + cross_attention_adaln=CROSS_ATTENTION_ADALN, + ) + 
audio_cfg = TransformerConfig( + dim=AUDIO_DIM, heads=AUDIO_HEADS, d_head=AUDIO_D_HEAD, + context_dim=AUDIO_CTX_DIM, + apply_gated_attention=APPLY_GATED_ATTENTION, + cross_attention_adaln=CROSS_ATTENTION_ADALN, + ) + block = BasicAVTransformerBlock( + idx=0, + video=video_cfg, + audio=audio_cfg, + rope_type=ROPE_TYPE, + norm_eps=NORM_EPS, + attention_function=AttentionFunction.DEFAULT, + ).to(device).eval() + + # Random init all params (skipping bias init that some Linear layers default to zero). + with torch.no_grad(): + for p in block.parameters(): + p.uniform_(-0.05, 0.05) + + # Dump every block parameter. + for name, p in block.state_dict().items(): + manifest.add_weight(name, p) + + # ---- Build the AV preprocessors (model-level adaln modules) ---- + # AdaLayerNormSingle with the right embedding_coefficient values. + # adaln_embedding_coefficient(cross_attention_adaln=True) = 9, False = 6. + coef_main = 9 if CROSS_ATTENTION_ADALN else 6 + video_adaln = AdaLayerNormSingle(VIDEO_DIM, embedding_coefficient=coef_main).to(device).eval() + audio_adaln = AdaLayerNormSingle(AUDIO_DIM, embedding_coefficient=coef_main).to(device).eval() + video_prompt_adaln = AdaLayerNormSingle(VIDEO_DIM, embedding_coefficient=2).to(device).eval() if CROSS_ATTENTION_ADALN else None + audio_prompt_adaln = AdaLayerNormSingle(AUDIO_DIM, embedding_coefficient=2).to(device).eval() if CROSS_ATTENTION_ADALN else None + av_ca_video_ss_adaln = AdaLayerNormSingle(VIDEO_DIM, embedding_coefficient=4).to(device).eval() + av_ca_audio_ss_adaln = AdaLayerNormSingle(AUDIO_DIM, embedding_coefficient=4).to(device).eval() + av_ca_a2v_gate_adaln = AdaLayerNormSingle(VIDEO_DIM, embedding_coefficient=1).to(device).eval() + av_ca_v2a_gate_adaln = AdaLayerNormSingle(AUDIO_DIM, embedding_coefficient=1).to(device).eval() + + with torch.no_grad(): + for m in [video_adaln, audio_adaln, video_prompt_adaln, audio_prompt_adaln, + av_ca_video_ss_adaln, av_ca_audio_ss_adaln, + av_ca_a2v_gate_adaln, 
av_ca_v2a_gate_adaln]: + if m is None: + continue + for p in m.parameters(): + p.uniform_(-0.05, 0.05) + + # ---- Build patchify projections (linear layers feeding the preprocessor) ---- + # in_channels arbitrary (we'll provide already-patchified x by setting in_channels=dim and using identity-like init? No, just pick small in_channels and let it project up). + VIDEO_IN_CHANNELS = 16 + AUDIO_IN_CHANNELS = 16 + video_patchify = torch.nn.Linear(VIDEO_IN_CHANNELS, VIDEO_DIM, bias=True).to(device).eval() + audio_patchify = torch.nn.Linear(AUDIO_IN_CHANNELS, AUDIO_DIM, bias=True).to(device).eval() + with torch.no_grad(): + for p in video_patchify.parameters(): p.uniform_(-0.05, 0.05) + for p in audio_patchify.parameters(): p.uniform_(-0.05, 0.05) + + # ---- Build preprocessors ---- + cross_pe_max_pos = max(VIDEO_MAX_POS[0], AUDIO_MAX_POS[0]) + video_prep = MultiModalTransformerArgsPreprocessor( + patchify_proj=video_patchify, + adaln=video_adaln, + cross_scale_shift_adaln=av_ca_video_ss_adaln, + cross_gate_adaln=av_ca_a2v_gate_adaln, + inner_dim=VIDEO_DIM, + max_pos=VIDEO_MAX_POS, + num_attention_heads=VIDEO_HEADS, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=USE_MIDDLE_INDICES_GRID, + audio_cross_attention_dim=AUDIO_CTX_DIM, + timestep_scale_multiplier=TIMESTEP_SCALE_MULT, + double_precision_rope=DOUBLE_PRECISION_ROPE, + positional_embedding_theta=POS_EMB_THETA, + rope_type=ROPE_TYPE, + av_ca_timestep_scale_multiplier=AV_CA_TS_SCALE_MULT, + caption_projection=None, + prompt_adaln=video_prompt_adaln, + ) + audio_prep = MultiModalTransformerArgsPreprocessor( + patchify_proj=audio_patchify, + adaln=audio_adaln, + cross_scale_shift_adaln=av_ca_audio_ss_adaln, + cross_gate_adaln=av_ca_v2a_gate_adaln, + inner_dim=AUDIO_DIM, + max_pos=AUDIO_MAX_POS, + num_attention_heads=AUDIO_HEADS, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=USE_MIDDLE_INDICES_GRID, + audio_cross_attention_dim=AUDIO_CTX_DIM, + 
timestep_scale_multiplier=TIMESTEP_SCALE_MULT, + double_precision_rope=DOUBLE_PRECISION_ROPE, + positional_embedding_theta=POS_EMB_THETA, + rope_type=ROPE_TYPE, + av_ca_timestep_scale_multiplier=AV_CA_TS_SCALE_MULT, + caption_projection=None, + prompt_adaln=audio_prompt_adaln, + ) + + # ---- Build modalities ---- + video_latent = torch.randn(B, T_VIDEO, VIDEO_IN_CHANNELS, dtype=dtype, device=device) + audio_latent = torch.randn(B, T_AUDIO, AUDIO_IN_CHANNELS, dtype=dtype, device=device) + video_context = torch.randn(B, S_VIDEO, VIDEO_CTX_DIM, dtype=dtype, device=device) + audio_context = torch.randn(B, S_AUDIO, AUDIO_CTX_DIM, dtype=dtype, device=device) + # No context mask → full attention. The preprocessor's _prepare_attention_mask + # only converts non-float masks to additive log-space bias; passing None is + # the cleanest way to skip masking entirely on both python and C++ sides. + video_ctx_mask = None + audio_ctx_mask = None + video_timesteps = torch.tensor([0.7], dtype=dtype, device=device) + audio_timesteps = torch.tensor([0.5], dtype=dtype, device=device) + video_sigma = torch.tensor([0.7], dtype=dtype, device=device) + audio_sigma = torch.tensor([0.5], dtype=dtype, device=device) + video_positions = make_video_positions(F_LAT, H_LAT, W_LAT, device) + audio_positions = make_audio_positions(T_AUDIO, device) + + video_modality = Modality( + latent=video_latent, + context=video_context, + context_mask=video_ctx_mask, + timesteps=video_timesteps.unsqueeze(0).unsqueeze(0), # [B, 1, 1] for AdaLN flatten + sigma=video_sigma, + positions=video_positions, + attention_mask=None, + enabled=True, + ) + audio_modality = Modality( + latent=audio_latent, + context=audio_context, + context_mask=audio_ctx_mask, + timesteps=audio_timesteps.unsqueeze(0).unsqueeze(0), + sigma=audio_sigma, + positions=audio_positions, + attention_mask=None, + enabled=True, + ) + + # ---- Run preprocessors ---- + video_args = video_prep.prepare(video_modality, cross_modality=audio_modality) + 
audio_args = audio_prep.prepare(audio_modality, cross_modality=video_modality) + + # ---- Dump TransformerArgs ---- + # video. + manifest.add_input("video__x", video_args.x) + manifest.add_input("video__context", video_args.context) + if video_args.context_mask is not None: + manifest.add_input("video__context_mask", video_args.context_mask) + manifest.add_input("video__timesteps", video_args.timesteps) + if video_args.prompt_timestep is not None: + manifest.add_input("video__prompt_timestep", video_args.prompt_timestep) + pe_v_cos, pe_v_sin = video_args.positional_embeddings + manifest.add_input("video__pe_cos", pe_v_cos) + manifest.add_input("video__pe_sin", pe_v_sin) + if video_args.cross_positional_embeddings is not None: + cpe_v_cos, cpe_v_sin = video_args.cross_positional_embeddings + manifest.add_input("video__cross_pe_cos", cpe_v_cos) + manifest.add_input("video__cross_pe_sin", cpe_v_sin) + if video_args.cross_scale_shift_timestep is not None: + manifest.add_input("video__cross_scale_shift_timestep", video_args.cross_scale_shift_timestep) + if video_args.cross_gate_timestep is not None: + manifest.add_input("video__cross_gate_timestep", video_args.cross_gate_timestep) + + # audio. 
+ manifest.add_input("audio__x", audio_args.x) + manifest.add_input("audio__context", audio_args.context) + if audio_args.context_mask is not None: + manifest.add_input("audio__context_mask", audio_args.context_mask) + manifest.add_input("audio__timesteps", audio_args.timesteps) + if audio_args.prompt_timestep is not None: + manifest.add_input("audio__prompt_timestep", audio_args.prompt_timestep) + pe_a_cos, pe_a_sin = audio_args.positional_embeddings + manifest.add_input("audio__pe_cos", pe_a_cos) + manifest.add_input("audio__pe_sin", pe_a_sin) + if audio_args.cross_positional_embeddings is not None: + cpe_a_cos, cpe_a_sin = audio_args.cross_positional_embeddings + manifest.add_input("audio__cross_pe_cos", cpe_a_cos) + manifest.add_input("audio__cross_pe_sin", cpe_a_sin) + if audio_args.cross_scale_shift_timestep is not None: + manifest.add_input("audio__cross_scale_shift_timestep", audio_args.cross_scale_shift_timestep) + if audio_args.cross_gate_timestep is not None: + manifest.add_input("audio__cross_gate_timestep", audio_args.cross_gate_timestep) + + # ---- Run the block forward ---- + with torch.no_grad(): + video_out, audio_out = block(video=video_args, audio=audio_args) + + manifest.add_output("video__x_out", video_out.x) + manifest.add_output("audio__x_out", audio_out.x) + + # ---- Write manifest ---- + manifest.write(OUT_DIR / "manifest.json") + print(f"[OK] Wrote {len(manifest.weights)} weights, " + f"{len(manifest.inputs)} inputs, {len(manifest.outputs)} outputs to {OUT_DIR}") + print(f"video.x_out shape={tuple(video_out.x.shape)} audio.x_out shape={tuple(audio_out.x.shape)}") + + +if __name__ == "__main__": + main() diff --git a/tests/ltx_parity/dump_av_model.py b/tests/ltx_parity/dump_av_model.py new file mode 100644 index 000000000..077abca11 --- /dev/null +++ b/tests/ltx_parity/dump_av_model.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +"""Dump a tiny LTX-2 AudioVideo model + inputs + outputs for C++ parity testing. 
+ +Build a deterministic-random LTXModel(model_type=AudioVideo, num_layers=2) with +the SAME flags as the production 22B (cross_attention_adaln=True, +apply_gated_attention=True), run forward(video, audio), and dump everything +needed for a C++ side-by-side comparison. + +Outputs: + /tmp/ltx_av_model_ref/ + manifest.json -- catalogue + weights/*.bin -- every model parameter + inputs/*.bin -- modality fields (latents, contexts, sigmas, etc.) PLUS + pre-scaled timesteps that LTXModel.forward_av expects + (caller pre-scales, mirroring the existing video-only + parity test pattern). Also dumps PE cos/sin pairs. + outputs/*.bin -- vx_out, ax_out + +Usage: + /home/ilintar/venv/bin/python tests/ltx_parity/dump_av_model.py +""" + +from __future__ import annotations + +import json +import pathlib +from dataclasses import dataclass, field +from typing import Dict + +import torch + +from ltx_core.guidance.perturbations import BatchedPerturbationConfig +from ltx_core.model.transformer.attention import AttentionFunction +from ltx_core.model.transformer.modality import Modality +from ltx_core.model.transformer.model import LTXModel, LTXModelType +from ltx_core.model.transformer.rope import LTXRopeType + + +# ============================================================================ +# Config +# ============================================================================ + +SEED = 0xCA11AB1E +OUT_DIR = pathlib.Path("/tmp/ltx_av_model_ref") +W_DIR = OUT_DIR / "weights" +IN_DIR = OUT_DIR / "inputs" +OUT_DIR_T = OUT_DIR / "outputs" + +# Tiny dims, full-feature flags. Audio inner_dim = audio.heads * audio.d_head. +# Production 22B uses video 32×128, audio 32×64; we mirror with 4×32 / 4×16. 
+VIDEO_HEADS, VIDEO_D_HEAD = 4, 32 +AUDIO_HEADS, AUDIO_D_HEAD = 4, 16 +VIDEO_DIM = VIDEO_HEADS * VIDEO_D_HEAD # 128 +AUDIO_DIM = AUDIO_HEADS * AUDIO_D_HEAD # 64 + +# Production has cross_attention_dim == video.dim and audio_cross_attention_dim +# == audio.dim (the cross_attention_adaln modulation requires it). +VIDEO_CROSS_ATTN_DIM = VIDEO_DIM +AUDIO_CROSS_ATTN_DIM = AUDIO_DIM + +VIDEO_IN_CHANNELS = 16 +VIDEO_OUT_CHANNELS = 16 +AUDIO_IN_CHANNELS = 8 +AUDIO_OUT_CHANNELS = 8 + +NUM_LAYERS = 2 + +CROSS_ATTENTION_ADALN = True +APPLY_GATED_ATTENTION = True +ROPE_TYPE = LTXRopeType.INTERLEAVED +NORM_EPS = 1e-6 + +VIDEO_MAX_POS = [20, 2048, 2048] +AUDIO_MAX_POS = [20] +USE_MIDDLE_INDICES_GRID = True +POS_EMB_THETA = 10000.0 +DOUBLE_PRECISION_ROPE = False +TIMESTEP_SCALE_MULT = 1000 +AV_CA_TS_SCALE_MULT = 1 + +B = 1 +F_LAT, H_LAT, W_LAT = 2, 3, 4 # video patch grid +T_VIDEO = F_LAT * H_LAT * W_LAT # 24 video tokens +T_AUDIO = 5 +S_VIDEO = 6 +S_AUDIO = 4 + + +# ============================================================================ +# Manifest helper +# ============================================================================ + + +@dataclass +class Manifest: + weights: Dict[str, Dict] = field(default_factory=dict) + inputs: Dict[str, Dict] = field(default_factory=dict) + outputs: Dict[str, Dict] = field(default_factory=dict) + config: Dict = field(default_factory=dict) + + def _add(self, target, name, t, dest): + arr = t.detach().to(torch.float32).contiguous().cpu().numpy() + path = dest / (name.replace("/", "__") + ".bin") + path.write_bytes(arr.tobytes()) + target[name] = {"shape": list(arr.shape), "path": str(path.relative_to(OUT_DIR))} + + def add_w(self, name, t): self._add(self.weights, name, t, W_DIR) + def add_i(self, name, t): self._add(self.inputs, name, t, IN_DIR) + def add_o(self, name, t): self._add(self.outputs, name, t, OUT_DIR_T) + + def write(self, path): + path.write_text(json.dumps({ + "config": self.config, + "weights": self.weights, + "inputs": 
self.inputs, + "outputs": self.outputs, + }, indent=2, default=str)) + + +# ============================================================================ +# Position grids +# ============================================================================ + + +def make_video_positions(F: int, H: int, W: int, device) -> torch.Tensor: + coords = [] + for f in range(F): + for h in range(H): + for w in range(W): + coords.append([f, h, w]) + coords = torch.tensor(coords, dtype=torch.float32, device=device).t() # [3, T] + grid = coords.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, 2).contiguous() # [B, 3, T, 2] + return grid + + +def make_audio_positions(T: int, device) -> torch.Tensor: + coords = torch.arange(T, dtype=torch.float32, device=device).unsqueeze(0) # [1, T] + grid = coords.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, 2).contiguous() + return grid + + +# ============================================================================ +# Main +# ============================================================================ + + +def main() -> None: + torch.manual_seed(SEED) + OUT_DIR.mkdir(parents=True, exist_ok=True) + W_DIR.mkdir(exist_ok=True) + IN_DIR.mkdir(exist_ok=True) + OUT_DIR_T.mkdir(exist_ok=True) + + device = torch.device("cpu") + dtype = torch.float32 + + manifest = Manifest() + manifest.config = { + "video": {"in_channels": VIDEO_IN_CHANNELS, "out_channels": VIDEO_OUT_CHANNELS, + "dim": VIDEO_DIM, "heads": VIDEO_HEADS, "d_head": VIDEO_D_HEAD, + "ctx_dim": VIDEO_CROSS_ATTN_DIM, "T": T_VIDEO, "S": S_VIDEO, + "F": F_LAT, "H": H_LAT, "W": W_LAT}, + "audio": {"in_channels": AUDIO_IN_CHANNELS, "out_channels": AUDIO_OUT_CHANNELS, + "dim": AUDIO_DIM, "heads": AUDIO_HEADS, "d_head": AUDIO_D_HEAD, + "ctx_dim": AUDIO_CROSS_ATTN_DIM, "T": T_AUDIO, "S": S_AUDIO}, + "B": B, + "num_layers": NUM_LAYERS, + "cross_attention_adaln": CROSS_ATTENTION_ADALN, + "apply_gated_attention": APPLY_GATED_ATTENTION, + "rope_type": ROPE_TYPE.value, + "norm_eps": NORM_EPS, + "video_max_pos": 
VIDEO_MAX_POS, + "audio_max_pos": AUDIO_MAX_POS, + "timestep_scale_multiplier": TIMESTEP_SCALE_MULT, + "av_ca_timestep_scale_multiplier": AV_CA_TS_SCALE_MULT, + "audio_cross_attention_dim": AUDIO_CROSS_ATTN_DIM, + "cross_pe_max_pos": max(VIDEO_MAX_POS[0], AUDIO_MAX_POS[0]), + "seed": SEED, + } + + # ---- Build the LTXModel ---- + model = LTXModel( + model_type=LTXModelType.AudioVideo, + num_attention_heads=VIDEO_HEADS, + attention_head_dim=VIDEO_D_HEAD, + in_channels=VIDEO_IN_CHANNELS, + out_channels=VIDEO_OUT_CHANNELS, + num_layers=NUM_LAYERS, + cross_attention_dim=VIDEO_CROSS_ATTN_DIM, + norm_eps=NORM_EPS, + attention_type=AttentionFunction.DEFAULT, + positional_embedding_theta=POS_EMB_THETA, + positional_embedding_max_pos=VIDEO_MAX_POS, + timestep_scale_multiplier=TIMESTEP_SCALE_MULT, + use_middle_indices_grid=USE_MIDDLE_INDICES_GRID, + audio_num_attention_heads=AUDIO_HEADS, + audio_attention_head_dim=AUDIO_D_HEAD, + audio_in_channels=AUDIO_IN_CHANNELS, + audio_out_channels=AUDIO_OUT_CHANNELS, + audio_cross_attention_dim=AUDIO_CROSS_ATTN_DIM, + audio_positional_embedding_max_pos=AUDIO_MAX_POS, + av_ca_timestep_scale_multiplier=AV_CA_TS_SCALE_MULT, + rope_type=ROPE_TYPE, + double_precision_rope=DOUBLE_PRECISION_ROPE, + apply_gated_attention=APPLY_GATED_ATTENTION, + caption_projection=None, + audio_caption_projection=None, + cross_attention_adaln=CROSS_ATTENTION_ADALN, + ).to(device).eval() + + # Random init. + with torch.no_grad(): + for p in model.parameters(): + p.uniform_(-0.05, 0.05) + + # Dump every state_dict tensor. 
+ for name, p in model.state_dict().items(): + manifest.add_w(name, p) + print(f"weights: {len(manifest.weights)}") + + # ---- Build modalities and run forward ---- + video_latent = torch.randn(B, T_VIDEO, VIDEO_IN_CHANNELS, dtype=dtype, device=device) + audio_latent = torch.randn(B, T_AUDIO, AUDIO_IN_CHANNELS, dtype=dtype, device=device) + video_context = torch.randn(B, S_VIDEO, VIDEO_CROSS_ATTN_DIM, dtype=dtype, device=device) + audio_context = torch.randn(B, S_AUDIO, AUDIO_CROSS_ATTN_DIM, dtype=dtype, device=device) + video_sigma = torch.tensor([0.7], dtype=dtype, device=device) + audio_sigma = torch.tensor([0.5], dtype=dtype, device=device) + + video_modality = Modality( + latent=video_latent, + context=video_context, + context_mask=None, + timesteps=video_sigma.view(B, 1, 1), + sigma=video_sigma, + positions=make_video_positions(F_LAT, H_LAT, W_LAT, device), + attention_mask=None, + enabled=True, + ) + audio_modality = Modality( + latent=audio_latent, + context=audio_context, + context_mask=None, + timesteps=audio_sigma.view(B, 1, 1), + sigma=audio_sigma, + positions=make_audio_positions(T_AUDIO, device), + attention_mask=None, + enabled=True, + ) + + perturbations = BatchedPerturbationConfig.empty(B) + with torch.no_grad(): + v_out, a_out = model(video=video_modality, audio=audio_modality, perturbations=perturbations) + + # ---- Inputs for the C++ side ---- + # video and audio latents (pre-patchify; LTXModel.forward_av runs patchify itself). + manifest.add_i("video__latent", video_latent) + manifest.add_i("audio__latent", audio_latent) + manifest.add_i("video__context", video_context) + manifest.add_i("audio__context", audio_context) + + # Pre-scaled timesteps. Our C++ LTXModel.forward_av takes them already scaled + # (mirroring the existing video-only LTXModel.forward convention). 
+ v_t_self = video_sigma * TIMESTEP_SCALE_MULT + a_t_self = audio_sigma * TIMESTEP_SCALE_MULT + v_t_prompt_self = v_t_self # cross_attention_adaln uses same σ scaling + a_t_prompt_self = a_t_self + # Cross-modality timesteps follow MultiModalTransformerArgsPreprocessor: + # cross_modality.sigma * timestep_scale_multiplier for the scale_shift adaln, + # cross_modality.sigma * av_ca_timestep_scale_multiplier for the gate adaln. + v_t_cross_ss = audio_sigma * TIMESTEP_SCALE_MULT + a_t_cross_ss = video_sigma * TIMESTEP_SCALE_MULT + v_t_cross_gate = audio_sigma * AV_CA_TS_SCALE_MULT + a_t_cross_gate = video_sigma * AV_CA_TS_SCALE_MULT + + manifest.add_i("video__t_self", v_t_self) + manifest.add_i("audio__t_self", a_t_self) + manifest.add_i("video__t_prompt_self", v_t_prompt_self) + manifest.add_i("audio__t_prompt_self", a_t_prompt_self) + manifest.add_i("video__t_cross_ss", v_t_cross_ss) + manifest.add_i("audio__t_cross_ss", a_t_cross_ss) + manifest.add_i("video__t_cross_gate", v_t_cross_gate) + manifest.add_i("audio__t_cross_gate", a_t_cross_gate) + + # Positional embeddings — re-derive via the model's preprocessors so we get + # exactly what the python forward saw. Then dump cos/sin separately so the + # C++ side can pack them into our [inner_dim, T, 2] layout. 
+ video_args = model.video_args_preprocessor.prepare(video_modality, audio_modality) + audio_args = model.audio_args_preprocessor.prepare(audio_modality, video_modality) + pe_v_cos, pe_v_sin = video_args.positional_embeddings + pe_a_cos, pe_a_sin = audio_args.positional_embeddings + cpe_v_cos, cpe_v_sin = video_args.cross_positional_embeddings + cpe_a_cos, cpe_a_sin = audio_args.cross_positional_embeddings + manifest.add_i("video__pe_cos", pe_v_cos) + manifest.add_i("video__pe_sin", pe_v_sin) + manifest.add_i("audio__pe_cos", pe_a_cos) + manifest.add_i("audio__pe_sin", pe_a_sin) + manifest.add_i("video__cross_pe_cos", cpe_v_cos) + manifest.add_i("video__cross_pe_sin", cpe_v_sin) + manifest.add_i("audio__cross_pe_cos", cpe_a_cos) + manifest.add_i("audio__cross_pe_sin", cpe_a_sin) + + manifest.add_o("video__x_out", v_out) + manifest.add_o("audio__x_out", a_out) + + manifest.write(OUT_DIR / "manifest.json") + print(f"[OK] Wrote {len(manifest.weights)} weights, {len(manifest.inputs)} inputs, " + f"{len(manifest.outputs)} outputs to {OUT_DIR}") + print(f"video.x_out shape={tuple(v_out.shape)} audio.x_out shape={tuple(a_out.shape)}") + + +if __name__ == "__main__": + main() diff --git a/tests/ltx_parity/dump_connector.py b/tests/ltx_parity/dump_connector.py new file mode 100644 index 000000000..76aa2f695 --- /dev/null +++ b/tests/ltx_parity/dump_connector.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +"""Dump tiny LTX-2 Connector V1 reference tensors for C++/GGML parity testing. 
+ +Covers: +- FeatureExtractorV1 (masked norm + aggregate_embed Linear) +- Embeddings1DConnector (2× BasicTransformerBlock1D + final rms_norm, with + num_learnable_registers weights present but unused on all-ones mask path) +- PixArtAlphaTextProjection (caption_projection inside the DiT) + +Usage: + /home/ilintar/venv/bin/python dump_connector.py +""" + +from __future__ import annotations + +import json +import math +import pathlib +from dataclasses import dataclass, field +from typing import Dict, List + +import numpy as np +import torch +from safetensors.torch import save_file + +from ltx_core.text_encoders.gemma.embeddings_connector import Embeddings1DConnector +from ltx_core.text_encoders.gemma.embeddings_processor import ( + EmbeddingsProcessor, + convert_to_additive_mask, +) +from ltx_core.text_encoders.gemma.feature_extractor import FeatureExtractorV1 +from ltx_core.model.transformer.rope import LTXRopeType +from ltx_core.model.transformer.text_projection import PixArtAlphaTextProjection + + +SEED = 0 + +# Two variants to exercise different connector paths: +# - "nopad" (default): SEQ_LEN=8, NUM_REGISTERS=4, mask all-ones. Register +# replacement is a no-op (reals fill everything) — +# covers the "skip concat" branch in C++. +# - "padded" (env CONNECTOR_VARIANT=padded): SEQ_LEN=8, NUM_REGISTERS=8, +# T_REAL=3 with left-padded mask [0,0,0,0,0,1,1,1]. +# Register replacement moves reals to the front and +# fills positions [T_REAL, num_reg) with the trailing +# slice of learnable_registers — this is the path the +# production conditioner/LTX2ConnectorRunner now takes +# when T_real < num_registers. 
import os

# Which connector path to exercise; see the variant notes in the module docstring.
VARIANT = os.environ.get("CONNECTOR_VARIANT", "nopad")
assert VARIANT in ("nopad", "padded")

OUT_DIR = pathlib.Path("/tmp/connector_ref" if VARIANT == "nopad" else "/tmp/connector_ref_padded")
TENSOR_DIR = OUT_DIR / "tensors"

# Tiny config (mirrors real LTX-2 head_dim=128 for fp16-stable attention;
# 2 heads keeps inner_dim small enough for fast parity).
NUM_HEADS = 2
HEAD_DIM = 32
INNER_DIM = NUM_HEADS * HEAD_DIM  # 64
NUM_LAYERS = 2
ROPE_THETA = 10_000.0
ROPE_MAX_POS = [1]

FEAT_NUM_LAYERS = 5  # fake "embed + 4 transformer layers"
FLAT_DIM = INNER_DIM * FEAT_NUM_LAYERS  # 80

CAPTION_CHANNELS = INNER_DIM  # 64
CAPTION_HIDDEN = 128          # DiT inner dim (larger than connector)
CAPTION_OUT = CAPTION_HIDDEN  # default: = hidden_size

BATCH = 1

if VARIANT == "nopad":
    NUM_REGISTERS = 4
    SEQ_LEN = 8  # > num_reg so register replacement is a no-op
    T_REAL = 8   # entire SEQ_LEN is real tokens
else:  # padded
    NUM_REGISTERS = 8
    SEQ_LEN = 8  # == num_reg (Python requires SEQ_LEN % NUM_REGISTERS == 0)
    T_REAL = 3   # left-padded: only last 3 positions are real

assert SEQ_LEN % NUM_REGISTERS == 0


@dataclass
class Manifest:
    """Catalogue of dumped tensors; serialised as manifest.json for the C++ side."""

    entries: List[Dict] = field(default_factory=list)

    def add(self, name: str, t: torch.Tensor):
        """Record a tensor's name/shape; the bytes themselves go through save_tensor."""
        self.entries.append({"name": name, "shape": list(t.shape), "dtype": "f32"})

    def dump(self, path: pathlib.Path):
        """Write all recorded entries as JSON to `path`."""
        path.write_text(json.dumps({"entries": self.entries}, indent=2))


def save_tensor(t: torch.Tensor, name: str, manifest: Manifest):
    """Write `t` as raw fp32 bytes under TENSOR_DIR and record it in `manifest`.

    '/' in `name` is flattened to '__' so every file lives in one directory.
    """
    safe_name = name.replace("/", "__")
    arr = t.detach().to(torch.float32).contiguous().cpu().numpy()
    arr.tofile(TENSOR_DIR / f"{safe_name}.bin")
    manifest.add(name, t)


def tame_(model: torch.nn.Module):
    """Overwrite `model`'s parameters in place with deterministic, bounded values.

    1-D norm-style weights are set to 1 (identity scale at init), other 1-D params
    (biases) to 0, and 2-D weights to a Kaiming-like N(0, 1/sqrt(fan_in)) draw from
    a generator seeded with SEED so Python and C++ load identical weights.
    """
    g = torch.Generator().manual_seed(SEED)
    with torch.no_grad():
        for name, p in model.named_parameters():
            if p.dim() == 1:
                # RMSNorm weights (standard, not Gemma's (1+w)) etc.: keep at 1 so
                # effective scale is identity at init. For plain biases zero is
                # also fine; we just keep the default shape.
                # NOTE: grouping made explicit — the original relied on `and`
                # binding tighter than `or`; semantics are unchanged.
                is_norm_weight = ("norm" in name.lower()) or (
                    name.split(".")[-1] == "weight" and p.shape[0] == INNER_DIM
                )
                if is_norm_weight:
                    p.fill_(1.0)
                else:
                    p.zero_()
            elif p.dim() == 2:
                fan_in = p.shape[1]
                std = 1.0 / math.sqrt(fan_in)
                p.normal_(mean=0.0, std=std, generator=g)
            else:
                p.normal_(mean=0.0, std=0.02, generator=g)


def main():
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    TENSOR_DIR.mkdir(parents=True, exist_ok=True)
    torch.manual_seed(SEED)

    # --- Build modules (tiny). ---
    aggregate_embed = torch.nn.Linear(FLAT_DIM, INNER_DIM, bias=False)
    feature_extractor = FeatureExtractorV1(aggregate_embed=aggregate_embed, is_av=False)

    connector = Embeddings1DConnector(
        attention_head_dim=HEAD_DIM,
        num_attention_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        positional_embedding_theta=ROPE_THETA,
        positional_embedding_max_pos=ROPE_MAX_POS,
        causal_temporal_positioning=False,
        num_learnable_registers=NUM_REGISTERS,
        rope_type=LTXRopeType.INTERLEAVED,
        # True = numpy fp64 linspace + pow cast to fp32 at the end. Matches our
        # C++ fp64 path byte-exactly. With False, torch's fp32 pow drifts 1 ULP
        # at the tail of the grid, causing ~5e-2 cos/sin diffs we can't reproduce.
        double_precision_rope=True,
        apply_gated_attention=False,
    )

    caption_projection = PixArtAlphaTextProjection(
        in_features=CAPTION_CHANNELS,
        hidden_size=CAPTION_HIDDEN,
        out_features=CAPTION_OUT,
        act_fn="gelu_tanh",
    )

    # Tame weights deterministically.
    tame_(feature_extractor)
    tame_(connector)
    tame_(caption_projection)

    # Cast to float32 (the tame() doesn't touch registers which default to bfloat16).
+ with torch.no_grad(): + if hasattr(connector, "learnable_registers"): + g = torch.Generator().manual_seed(SEED + 1) + connector.learnable_registers.data = ( + torch.rand(NUM_REGISTERS, INNER_DIM, generator=g) * 2.0 - 1.0 + ).to(torch.float32) + + feature_extractor.eval() + connector.eval() + caption_projection.eval() + + # --- Build inputs. --- + rng = np.random.default_rng(SEED + 2) + # Pretend 49-layer stack (tiny): [B, T, D=INNER_DIM, L=FEAT_NUM_LAYERS] + stacked = torch.tensor( + rng.normal(loc=0.0, scale=1.0, size=(BATCH, SEQ_LEN, INNER_DIM, FEAT_NUM_LAYERS)), + dtype=torch.float32, + ) + # Binary attention mask. Left-padded when VARIANT="padded": first + # (SEQ_LEN - T_REAL) positions are pad (0), last T_REAL are real (1). + attention_mask = torch.ones((BATCH, SEQ_LEN), dtype=torch.int64) + if VARIANT == "padded": + attention_mask[:, : SEQ_LEN - T_REAL] = 0 + # Zero-out the padded positions in the stacked input too, matching what + # the real HF pipeline feeds (padded tokens have zero embeddings after + # feature extraction since Gemma's pad_token embedding is unused in the + # text-to-video pipeline — FeatureExtractor masks them out anyway). + stacked[:, : SEQ_LEN - T_REAL, :, :] = 0 + + manifest = Manifest() + save_tensor(stacked, "stacked_in", manifest) + save_tensor(attention_mask.to(torch.float32), "attention_mask", manifest) + + # --- 1. Feature extractor. --- + with torch.no_grad(): + feat_out, _ = feature_extractor(stacked, attention_mask, padding_side="left") + save_tensor(feat_out, "feat_ext_out", manifest) + print(f" feat_ext_out shape={tuple(feat_out.shape)} " + f"mean={feat_out.mean().item():.4f} std={feat_out.std().item():.4f}") + + # --- 2. Connector. --- + additive_mask = convert_to_additive_mask(attention_mask, feat_out.dtype) + # Run connector piece-by-piece to capture intermediates. + with torch.no_grad(): + hs = feat_out + am = additive_mask + # Register replacement (no-op for all-ones mask, but exercises the path). 
+ if connector.num_learnable_registers: + hs, am = connector._replace_padded_with_learnable_registers(hs, am) + save_tensor(hs, "after_registers", manifest) + + indices_grid = torch.arange(hs.shape[1], dtype=torch.float32) + indices_grid = indices_grid[None, None, :] + from ltx_core.model.transformer.rope import ( + generate_freq_grid_np, + generate_freq_grid_pytorch, + precompute_freqs_cis, + ) + freq_gen = generate_freq_grid_np if connector.double_precision_rope else generate_freq_grid_pytorch + freqs_cis = precompute_freqs_cis( + indices_grid=indices_grid, + dim=connector.inner_dim, + out_dtype=hs.dtype, + theta=connector.positional_embedding_theta, + max_pos=connector.positional_embedding_max_pos, + num_attention_heads=connector.num_attention_heads, + rope_type=connector.rope_type, + freq_grid_generator=freq_gen, + ) + cos_f, sin_f = freqs_cis + save_tensor(cos_f, "rope_cos", manifest) + save_tensor(sin_f, "rope_sin", manifest) + + for i, block in enumerate(connector.transformer_1d_blocks): + hs = block(hs, attention_mask=am, pe=freqs_cis) + save_tensor(hs, f"conn_block_{i}_out", manifest) + print(f" conn_block_{i}_out shape={tuple(hs.shape)} " + f"mean={hs.mean().item():.4f} std={hs.std().item():.4f}") + + from ltx_core.utils import rms_norm + hs = rms_norm(hs) + save_tensor(hs, "conn_final_out", manifest) + print(f" conn_final_out shape={tuple(hs.shape)} " + f"mean={hs.mean().item():.4f} std={hs.std().item():.4f}") + + # --- 3. Caption projection. --- + with torch.no_grad(): + caption_out = caption_projection(hs) + save_tensor(caption_out, "caption_proj_out", manifest) + print(f" caption_proj_out shape={tuple(caption_out.shape)} " + f"mean={caption_out.mean().item():.4f} std={caption_out.std().item():.4f}") + + # --- Save state dict under C++-friendly keys. 
--- + state: Dict[str, torch.Tensor] = {} + # Feature extractor + state["feature_extractor.aggregate_embed.weight"] = ( + feature_extractor.aggregate_embed.weight.detach().to(torch.float32).contiguous() + ) + # Connector parameters + for key, value in connector.state_dict().items(): + state[f"connector.{key}"] = value.detach().to(torch.float32).contiguous() + # Caption projection + for key, value in caption_projection.state_dict().items(): + state[f"caption_projection.{key}"] = value.detach().to(torch.float32).contiguous() + + save_file(state, str(OUT_DIR / "state_dict.safetensors")) + (OUT_DIR / "tensor_names.txt").write_text("\n".join(sorted(state.keys())) + "\n") + manifest.dump(OUT_DIR / "manifest.json") + + (OUT_DIR / "config.json").write_text(json.dumps({ + "num_heads": NUM_HEADS, + "head_dim": HEAD_DIM, + "inner_dim": INNER_DIM, + "num_layers": NUM_LAYERS, + "num_registers": NUM_REGISTERS, + "rope_theta": ROPE_THETA, + "rope_max_pos": ROPE_MAX_POS, + "feat_num_layers": FEAT_NUM_LAYERS, + "flat_dim": FLAT_DIM, + "caption_channels": CAPTION_CHANNELS, + "caption_hidden": CAPTION_HIDDEN, + "caption_out": CAPTION_OUT, + "seq_len": SEQ_LEN, + "batch": BATCH, + }, indent=2)) + + print(f"\nDone. {len(manifest.entries)} tensors → {OUT_DIR}") + print(f"State dict: {len(state)} keys → {OUT_DIR}/state_dict.safetensors") + + +if __name__ == "__main__": + main() diff --git a/tests/ltx_parity/dump_gemma.py b/tests/ltx_parity/dump_gemma.py new file mode 100644 index 000000000..b295f9d04 --- /dev/null +++ b/tests/ltx_parity/dump_gemma.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +"""Dump tiny Gemma 3 reference tensors for C++/GGML parity testing. 
+ +Strategy mirrors dump_reference.py: instantiate a tiny Gemma3TextModel (6 layers so +sliding-window pattern triggers at layer 5, small dims) with deterministic tamed +weights, run one forward pass on fixed input_ids, and write every intermediate tensor +(embedding-post-scale, per-layer output, all-layer stack, final norm) to +/tmp/gemma_ref/tensors/ as raw fp32 bytes. Also write the state_dict as safetensors +so the C++ side can load identical weights. + +Usage: + /home/ilintar/venv/bin/python dump_gemma.py +""" + +from __future__ import annotations + +import json +import math +import pathlib +from dataclasses import dataclass, field +from typing import Dict, List + +import numpy as np +import torch +from safetensors.torch import save_file +from transformers import Gemma3TextConfig, Gemma3TextModel + +# -------- Config -------- + +SEED = 0 +OUT_DIR = pathlib.Path("/tmp/gemma_ref") +TENSOR_DIR = OUT_DIR / "tensors" + +# Select config via GEMMA_PARITY_VARIANT env var: "tiny" (default) or "deep". +# The tiny variant is fast but only exercises 6 stacked layers with tamed weights. +# The deep variant scales to 24 layers × 512 hidden to stress-test accumulated drift +# and the full sliding/global interleave pattern (same sliding_window_pattern=6 as +# the real Gemma 3 12B). Shared code path in both — differences are pure scaling. +import os +VARIANT = os.environ.get("GEMMA_PARITY_VARIANT", "tiny") + +if VARIANT == "deep": + NUM_LAYERS = 24 + HIDDEN_SIZE = 512 + NUM_HEADS = 8 + NUM_KV_HEADS = 4 + HEAD_DIM = 64 + INTERMEDIATE_SIZE = 1024 + VOCAB_SIZE = 1024 + SLIDING_WINDOW = 16 + SLIDING_WINDOW_PATTERN = 6 + SEQ_LEN = 32 # > sliding_window so the sliding mask actually bites + # Each "deep" run reuses /tmp/gemma_ref but under a distinct tensor prefix so + # test_gemma_parity.cpp can load both files without key collision. + TENSOR_PREFIX_MODEL = "text_encoder_deep.model" + TENSOR_TAG_PREFIX = "deep_" # applied to tensor output filenames +else: + # Tiny Gemma 3 config. 
    # 6 layers so layer index 5 is the first (and only) global
    # layer under the (i+1)%6 rule — exercises both sliding and full paths.
    NUM_LAYERS = 6
    HIDDEN_SIZE = 128
    NUM_HEADS = 4
    NUM_KV_HEADS = 2
    HEAD_DIM = 32  # NOTE: != HIDDEN_SIZE / NUM_HEADS. Matches Gemma's non-standard head_dim.
    INTERMEDIATE_SIZE = 256
    VOCAB_SIZE = 512
    SLIDING_WINDOW = 4
    SLIDING_WINDOW_PATTERN = 6
    SEQ_LEN = 8
    TENSOR_PREFIX_MODEL = "text_encoder.model"
    TENSOR_TAG_PREFIX = ""

RMS_EPS = 1e-6
ROPE_THETA = 1_000_000.0
ROPE_LOCAL_THETA = 10_000.0
BATCH = 1

TENSOR_PREFIX = TENSOR_PREFIX_MODEL  # Our LLM wrapper stores TextModel under .model,
                                     # so the full key is prefix.model..


# -------- Utility --------


@dataclass
class Manifest:
    # Catalogue of dumped tensors, serialised to manifest.json for the C++ test side.
    entries: List[Dict] = field(default_factory=list)

    def add(self, name: str, t: torch.Tensor):
        """Record name/shape metadata for a tensor (bytes are written by save_tensor)."""
        self.entries.append({"name": name, "shape": list(t.shape), "dtype": "f32"})

    def dump(self, path: pathlib.Path):
        """Serialise all recorded entries as JSON to `path`."""
        path.write_text(json.dumps({"entries": self.entries}, indent=2))


def save_tensor(t: torch.Tensor, name: str, manifest: Manifest):
    """Write `t` as raw (native-endian) fp32 bytes under TENSOR_DIR and record it
    in `manifest`. '/' in `name` is flattened to '__' so all files share one dir.
    """
    safe_name = name.replace("/", "__")
    arr = t.detach().to(torch.float32).contiguous().cpu().numpy()
    arr.tofile(TENSOR_DIR / f"{safe_name}.bin")
    manifest.add(name, t)


def tame_(model: torch.nn.Module):
    """Apply deterministic, finite weights. Mirrors dump_reference.py's approach:
    RMSNorm weights = 1, linears ~= Kaiming with a smaller gain.
    """
    g = torch.Generator().manual_seed(SEED)
    with torch.no_grad():
        for name, p in model.named_parameters():
            if p.dim() == 1:
                # All 1D params: RMS weights, LN weights, biases. Gemma uses RMS with
                # `weight = zeros` + `(1 + weight)` pattern; see below — keep at 0 so
                # effective scale is 1.0 at init.
                p.zero_()
            elif p.dim() == 2:
                # Kaiming-like: std = 1/sqrt(fan_in), drawn from a seeded generator
                # so repeated runs produce byte-identical state dicts.
                fan_in = p.shape[1]
                std = 1.0 / math.sqrt(fan_in)
                p.normal_(mean=0.0, std=std, generator=g)
            else:
                p.normal_(mean=0.0, std=0.02, generator=g)


# -------- Main --------


def main():
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    TENSOR_DIR.mkdir(parents=True, exist_ok=True)
    torch.manual_seed(SEED)

    # Real Gemma 3 12B config has rope_scaling={"rope_type": "linear", "factor": 8.0}
    # applied to full_attention layers only (HuggingFace gemma3 config.json). Mirror that
    # here in the deep variant so C++ parity actually exercises the scaling path. The
    # tiny variant keeps scaling disabled (factor=1) for faster iteration / backward compat.
    rope_scaling = {"rope_type": "linear", "factor": 8.0} if VARIANT == "deep" else None
    config = Gemma3TextConfig(
        vocab_size=VOCAB_SIZE,
        hidden_size=HIDDEN_SIZE,
        intermediate_size=INTERMEDIATE_SIZE,
        num_hidden_layers=NUM_LAYERS,
        num_attention_heads=NUM_HEADS,
        num_key_value_heads=NUM_KV_HEADS,
        head_dim=HEAD_DIM,
        rms_norm_eps=RMS_EPS,
        rope_theta=ROPE_THETA,
        rope_local_base_freq=ROPE_LOCAL_THETA,
        rope_scaling=rope_scaling,
        sliding_window=SLIDING_WINDOW,
        sliding_window_pattern=SLIDING_WINDOW_PATTERN,
        max_position_embeddings=1024,
        attention_bias=False,
        attn_logit_softcapping=None,
        final_logit_softcapping=None,
        query_pre_attn_scalar=HEAD_DIM,  # 1/sqrt(head_dim) scaling
        hidden_activation="gelu_pytorch_tanh",
    )

    print("Config summary:")
    print(f"  layer_types: {config.layer_types}")
    print(f"  hidden_size: {config.hidden_size}")
    print(f"  head_dim: {config.head_dim}")
    print(f"  sliding_window: {config.sliding_window}")

    model = Gemma3TextModel(config)
    model.eval()
    tame_(model)

    # Fixed input ids.
+ rng = np.random.default_rng(SEED) + input_ids = torch.tensor(rng.integers(low=0, high=VOCAB_SIZE, size=(BATCH, SEQ_LEN)), dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + + manifest = Manifest() + save_tensor(input_ids.to(torch.float32), f"{TENSOR_TAG_PREFIX}input_ids", manifest) # store as f32 for simplicity + + # Forward with output_hidden_states=True to capture every layer. + with torch.no_grad(): + out = model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # out.hidden_states is a tuple: (embedding_out, layer_0_out, layer_1_out, ..., layer_N_out) + # So len == num_layers + 1. First element is post-embed-scale. + hidden_states = out.hidden_states + assert len(hidden_states) == NUM_LAYERS + 1, f"got {len(hidden_states)} hidden states" + + for i, h in enumerate(hidden_states): + tag = f"{TENSOR_TAG_PREFIX}hs_{i:02d}" if i > 0 else f"{TENSOR_TAG_PREFIX}hs_embed" + save_tensor(h, tag, manifest) + if VARIANT == "tiny" or i % 4 == 0 or i == len(hidden_states) - 1: + # Keep logs short for the deep variant. + print(f" {tag}: shape={tuple(h.shape)} mean={h.mean().item():.4f} std={h.std().item():.4f}") + + # Final norm'd output (post model.norm, which is Gemma's output rms). + save_tensor(out.last_hidden_state, f"{TENSOR_TAG_PREFIX}last_hidden_state", manifest) + + # Stacked all-(N+1)-layer tensor as LTX-2 consumes it: + # torch.stack(hidden_states, dim=-1) -> [B, T, H, N+1] + stacked = torch.stack(hidden_states, dim=-1) + save_tensor(stacked, f"{TENSOR_TAG_PREFIX}all_layers_stacked", manifest) + print(f" all_layers_stacked: shape={tuple(stacked.shape)}") + + # Write state dict with our C++-side prefix convention. + # If a prior run (e.g. "tiny" → then "deep") already wrote a state_dict with a + # different prefix, merge instead of overwriting so both variants live in one + # safetensors file and the C++ test can load either config path on demand. 
+ state_dict = model.state_dict() + prefixed = {f"{TENSOR_PREFIX}.{k}": v.to(torch.float32).contiguous() for k, v in state_dict.items()} + sd_path = OUT_DIR / "state_dict.safetensors" + if sd_path.exists(): + try: + from safetensors.torch import load_file + existing = load_file(str(sd_path)) + for k, v in existing.items(): + if k.startswith(f"{TENSOR_PREFIX}."): + continue # replace our own prefix on re-run + prefixed[k] = v + print(f" merged {len(existing)} existing tensors into new state_dict") + except Exception as e: + print(f" warning: could not merge existing state_dict ({e}); overwriting") + save_file(prefixed, str(sd_path)) + (OUT_DIR / "tensor_names.txt").write_text("\n".join(sorted(prefixed.keys())) + "\n") + manifest.dump(OUT_DIR / f"manifest{'_deep' if VARIANT == 'deep' else ''}.json") + + # Also dump config JSON so C++ side can cross-check shapes if needed. + (OUT_DIR / "config.json").write_text(json.dumps({ + "num_layers": NUM_LAYERS, + "hidden_size": HIDDEN_SIZE, + "num_heads": NUM_HEADS, + "num_kv_heads": NUM_KV_HEADS, + "head_dim": HEAD_DIM, + "intermediate_size": INTERMEDIATE_SIZE, + "vocab_size": VOCAB_SIZE, + "rms_norm_eps": RMS_EPS, + "sliding_window": SLIDING_WINDOW, + "sliding_window_pattern": SLIDING_WINDOW_PATTERN, + "rope_theta_global": ROPE_THETA, + "rope_theta_local": ROPE_LOCAL_THETA, + "seq_len": SEQ_LEN, + "batch": BATCH, + "embed_scale": math.sqrt(HIDDEN_SIZE), + "layer_types": config.layer_types, + "tensor_prefix": TENSOR_PREFIX, + }, indent=2)) + + print(f"\nDone. 
Wrote {len(manifest.entries)} tensors under {OUT_DIR}.") + print(f"State dict: {len(prefixed)} keys → {OUT_DIR}/state_dict.safetensors") + print(f"Manifest: {OUT_DIR}/manifest.json") + print(f"Name list: {OUT_DIR}/tensor_names.txt") + + +if __name__ == "__main__": + main() diff --git a/tests/ltx_parity/dump_reference.py b/tests/ltx_parity/dump_reference.py new file mode 100644 index 000000000..97f8476b9 --- /dev/null +++ b/tests/ltx_parity/dump_reference.py @@ -0,0 +1,623 @@ +#!/usr/bin/env python3 +"""Dump LTX-2 reference tensors for C++/GGML parity testing. + +Strategy: instantiate a TINY LTX-2 model (2 layers, small dims) with deterministic +random weights, run a single forward pass on fixed inputs, and write every intermediate +tensor (post-each-block, post-AdaLN, post-patchify, final output) to +/tmp/ltx_ref/tensors/ as raw fp32 bytes. Also dump the state_dict as safetensors so the +C++ side can load the exact same weights. + +Usage: + /home/ilintar/venv/bin/python dump_reference.py + +Outputs: + /tmp/ltx_ref/manifest.json -- catalogue of every dumped tensor + /tmp/ltx_ref/state_dict.safetensors -- model weights + /tmp/ltx_ref/tensors/*.bin -- raw fp32 bytes, one file per tensor + /tmp/ltx_ref/tensor_names.txt -- state_dict.keys() for name-mapping verification +""" + +from __future__ import annotations + +import json +import os +import pathlib +from dataclasses import dataclass, field +from typing import Dict, List + +import numpy as np +import torch +from safetensors.torch import save_file + +from ltx_core.components.schedulers import LTX2Scheduler +from ltx_core.model.transformer.adaln import AdaLayerNormSingle +from ltx_core.model.transformer.attention import Attention, AttentionFunction +from ltx_core.model.transformer.feed_forward import FeedForward +from ltx_core.model.transformer.model import LTXModel, LTXModelType +from ltx_core.model.transformer.modality import Modality +from ltx_core.model.transformer.rope import ( + LTXRopeType, + apply_rotary_emb, 
    generate_freq_grid_pytorch,
    precompute_freqs_cis,
)
from ltx_core.model.transformer.timestep_embedding import (
    PixArtAlphaCombinedTimestepSizeEmbeddings,
)
from ltx_core.guidance.perturbations import BatchedPerturbationConfig

# -------- Config --------

SEED = 0
OUT_DIR = pathlib.Path("/tmp/ltx_ref")
TENSOR_DIR = OUT_DIR / "tensors"

# Tiny model config — deliberately small so every tensor is cheap to dump.
INNER_DIM = 128
NUM_HEADS = 4
HEAD_DIM = 32  # NUM_HEADS * HEAD_DIM = INNER_DIM
NUM_LAYERS = 2
IN_CHANNELS = 16
OUT_CHANNELS = 16
CROSS_ATTN_DIM = 128  # keep == INNER_DIM to avoid needing caption_projection for now
NORM_EPS = 1e-6

# Toy latent (F, H, W) — small but with at least 2 frames to exercise temporal axis.
F_LAT, H_LAT, W_LAT = 2, 4, 6
BATCH = 1
FPS = 24.0

# Synthetic text context.
CONTEXT_LEN = 8


# -------- Utility --------


@dataclass
class Manifest:
    # Catalogue of dumped tensors. Unlike the other dump scripts, add() here both
    # writes the raw bytes AND records the metadata entry.
    entries: List[Dict] = field(default_factory=list)

    def add(self, name: str, tensor: torch.Tensor, notes: str = ""):
        """Write `tensor` as raw fp32 bytes under TENSOR_DIR and record its metadata."""
        t = tensor.detach().to(torch.float32).contiguous().cpu()
        # Flatten name → filename by replacing '/' with '__' so everything lives in one dir.
        fname = name.replace("/", "__") + ".bin"
        path = TENSOR_DIR / fname
        path.write_bytes(t.numpy().tobytes())
        self.entries.append(
            {
                "name": name,
                "shape": list(t.shape),
                "dtype": "float32",
                "nbytes": t.numel() * 4,
                "path": str(path.relative_to(OUT_DIR)),
                "notes": notes,
            }
        )

    def dump(self, path: pathlib.Path):
        """Serialise all recorded entries as JSON to `path`."""
        path.write_text(json.dumps({"entries": self.entries}, indent=2))


def seeded_randn(shape, seed_offset=0):
    """Deterministic N(0,1) fp32 tensor: same (shape, seed_offset) → same values."""
    g = torch.Generator().manual_seed(SEED + seed_offset)
    return torch.randn(shape, generator=g, dtype=torch.float32)


# -------- Dumpers --------


def dump_rope():
    """Dump RoPE freqs_cis + apply_rotary_emb result for a known grid."""
    # 3D positions, middle-grid form: shape [B, n_pos_dims, T, 2] with (start, end) pairs.
    F, H, W = F_LAT, H_LAT, W_LAT
    T = F * H * W
    positions = torch.zeros(BATCH, 3, T, 2, dtype=torch.float32)
    idx = 0
    for f in range(F):
        for h in range(H):
            for w in range(W):
                # Time axis divided by fps per ltx_pipelines/utils/tools.py:135.
                positions[0, 0, idx, 0] = f / FPS
                positions[0, 0, idx, 1] = (f + 1) / FPS
                positions[0, 1, idx, 0] = h
                positions[0, 1, idx, 1] = h + 1
                positions[0, 2, idx, 0] = w
                positions[0, 2, idx, 1] = w + 1
                idx += 1

    cos, sin = precompute_freqs_cis(
        positions,
        dim=INNER_DIM,
        out_dtype=torch.float32,
        theta=10000.0,
        max_pos=[20, 2048, 2048],
        use_middle_indices_grid=True,
        num_attention_heads=NUM_HEADS,
        rope_type=LTXRopeType.SPLIT,
        freq_grid_generator=generate_freq_grid_pytorch,
    )

    # Apply to a known q tensor so we can diff both the pe itself and the post-rotation output.
    q = seeded_randn((BATCH, T, INNER_DIM), seed_offset=100)
    q_rot = apply_rotary_emb(q, (cos, sin), LTXRopeType.SPLIT)

    m = {}
    m["rope/positions"] = positions
    m["rope/cos"] = cos
    m["rope/sin"] = sin
    m["rope/q_in"] = q
    m["rope/q_rotated"] = q_rot
    return m


def dump_scheduler():
    """LTX2Scheduler output for a few representative configurations.
    Keys: 'schedule/tokens_{N}_steps_{S}' → sigma array of length S+1.
    """
    scheduler = LTX2Scheduler()
    cases = [
        # (tokens, steps, stretch, terminal)
        (1024, 10, True, 0.1),   # small latent (BASE_SHIFT anchor)
        (1024, 30, True, 0.1),
        (4096, 10, True, 0.1),   # MAX_SHIFT anchor
        (4096, 40, True, 0.1),   # typical LTX-2 default
        (2560, 30, True, 0.1),   # interpolated
        (4096, 8, False, 0.1),   # no stretch path
    ]
    out = {}
    for tokens, steps, stretch, terminal in cases:
        # LTX2Scheduler expects a `latent` tensor to derive tokens from shape[2:].
        # Fake one with product(shape[2:]) == tokens.
+ fake_latent = torch.zeros(1, 1, tokens) + sigmas = scheduler.execute( + steps=steps, latent=fake_latent, + max_shift=2.05, base_shift=0.95, + stretch=stretch, terminal=terminal, + ) + key = f"schedule/tokens{tokens}_steps{steps}_stretch{int(stretch)}" + out[key] = sigmas.detach().float() + return out + + +def dump_adaln(): + """AdaLayerNormSingle: t → (modulation[B, coeff, dim], embedded[B, dim]).""" + torch.manual_seed(SEED + 2) + adaln = AdaLayerNormSingle(embedding_dim=INNER_DIM, embedding_coefficient=6).eval() + + # Fixed timestep σ ∈ (0, 1). Python applies *1000 externally; mirror that here. + sigma = torch.tensor([0.42], dtype=torch.float32) + t_scaled = sigma * 1000.0 + + with torch.no_grad(): + modulation, embedded = adaln(t_scaled, hidden_dtype=torch.float32) + + # Extract sub-weights for loading into C++. The isolated AdaLN test weights are not loaded into + # the full LTXRunner, so the prefix only needs to be unique w.r.t. the full-model weights. + sd = {f"adaln_standalone.{k}": v.detach().float() for k, v in adaln.state_dict().items()} + + return { + "adaln/sigma": sigma, + "adaln/t_scaled": t_scaled, + "adaln/modulation": modulation, + "adaln/embedded": embedded, + }, sd + + +def dump_full_model(): + """Tiny LTXModel (VideoOnly) forward, dumping per-block outputs.""" + torch.manual_seed(SEED + 3) + + # Stash a helper to tame magnitudes for parity testing. With default init, scale_shift_table is + # torch.empty(...) (uninitialised memory — random garbage) and many Linears have Kaiming init + # which, compounded across blocks with AdaLN * (1 + scale) + shift modulation, produces values + # that overflow fp32 (output becomes NaN). We don't care about the semantics of the weights — + # only that C++ and Python compute the SAME function on the SAME weights — so we replace them + # with bounded random values post-construction. 
+ + model = LTXModel( + model_type=LTXModelType.VideoOnly, + num_attention_heads=NUM_HEADS, + attention_head_dim=HEAD_DIM, + in_channels=IN_CHANNELS, + out_channels=OUT_CHANNELS, + num_layers=NUM_LAYERS, + cross_attention_dim=CROSS_ATTN_DIM, + norm_eps=NORM_EPS, + attention_type=AttentionFunction.PYTORCH, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + timestep_scale_multiplier=1000, + use_middle_indices_grid=True, + rope_type=LTXRopeType.SPLIT, + double_precision_rope=False, + apply_gated_attention=False, + caption_projection=None, # cross_attention_dim == inner_dim so no projection needed + cross_attention_adaln=False, + ).eval() + + # Tame weights: small random gaussians with scale 1/sqrt(dim), scale_shift_table zeroed so the + # forward is well-conditioned even with two randomly-initialised blocks stacked. + with torch.no_grad(): + for name, p in model.named_parameters(): + if "scale_shift_table" in name: + p.zero_() + continue + # q_norm / k_norm RMSNorm weights must be ~1 to act as normalisers (not kill signal). + if name.endswith("q_norm.weight") or name.endswith("k_norm.weight"): + p.fill_(1.0) + continue + if p.dim() == 1: + # biases → zero + p.zero_() + else: + # Kaiming-ish: std ~ 1/sqrt(fan_in) + fan_in = p.shape[1] if p.dim() >= 2 else p.numel() + p.normal_(0.0, 1.0 / (fan_in ** 0.5)) + + # Synthetic inputs. + F, H, W = F_LAT, H_LAT, W_LAT + T = F * H * W + latent = seeded_randn((BATCH, IN_CHANNELS, F, H, W), seed_offset=200) + sigma = torch.tensor([0.5], dtype=torch.float32) + context = seeded_randn((BATCH, CONTEXT_LEN, CROSS_ATTN_DIM), seed_offset=300) + + # Build positions in (B, n_pos_dims, T, 2) middle-grid form. 
+ positions = torch.zeros(BATCH, 3, T, 2, dtype=torch.float32) + idx = 0 + for f in range(F): + for h in range(H): + for w in range(W): + positions[0, 0, idx, 0] = f / FPS + positions[0, 0, idx, 1] = (f + 1) / FPS + positions[0, 1, idx, 0] = h + positions[0, 1, idx, 1] = h + 1 + positions[0, 2, idx, 0] = w + positions[0, 2, idx, 1] = w + 1 + idx += 1 + + # LTX's Modality carries the latent pre-patchify (shape [B, C, F, H, W] → flat [B, T, C]). + # patchify_proj is Linear(in_channels, inner_dim) so we need [B, T, C] for input. + latent_flat = latent.permute(0, 2, 3, 4, 1).reshape(BATCH, T, IN_CHANNELS) + + # For pure T2V, per-token timesteps = sigma broadcast. + timesteps = sigma.view(BATCH, 1).expand(BATCH, T).contiguous() + + # Positions shape the preprocessor wants is [B, 3, T] (no middle-grid pair dim) when + # use_middle_indices_grid=False, or [B, 3, T, 2] when True. Our positions are already the + # [B, 3, T, 2] form. Good. + modality = Modality( + latent=latent_flat, + sigma=sigma, + timesteps=timesteps, + positions=positions, + context=context, + enabled=True, + context_mask=None, + attention_mask=None, + ) + + # Instrument: intercept transformer_blocks outputs so we can dump per-block. + per_block_outputs = {} + orig_forwards = [] + for i, blk in enumerate(model.transformer_blocks): + orig = blk.forward + orig_forwards.append(orig) + + def make_capture(idx, original): + def capture(video=None, audio=None, perturbations=None): + out_video, out_audio = original(video=video, audio=audio, perturbations=perturbations) + per_block_outputs[f"block_{idx:02d}_out"] = out_video.x.detach().float().clone() + return out_video, out_audio + return capture + + blk.forward = make_capture(i, orig) + + with torch.no_grad(): + vx, _ = model(video=modality, audio=None, perturbations=BatchedPerturbationConfig.empty(BATCH)) + + # Also capture post-patchify result by running patchify_proj manually (same computation). 
+ with torch.no_grad(): + patchified = model.patchify_proj(latent_flat) + tm_mod, tm_embedded = model.adaln_single(timesteps.flatten() * 1000.0, hidden_dtype=torch.float32) + tm_mod = tm_mod.view(BATCH, -1, tm_mod.shape[-1]) + tm_embedded = tm_embedded.view(BATCH, -1, tm_embedded.shape[-1]) + + # Also save the unflattened latent in [C, F, H, W] order (batch=1 squeezed). + # Memory layout: W innermost → matches ggml ne=[W, H, F, C] which is what LTXRunner::build_graph + # expects at its entry point. Convert by squeezing batch dim from the original [B=1, C, F, H, W]. + latent_unflat = latent.squeeze(0) # [C, F, H, W] + + # Velocity output comes out of the Python model as [B, T, C=out_channels]. Also save the + # unflattened [C, F, H, W] form so C++ can compare without reshaping. + vx_unflat = vx.reshape(BATCH, F, H, W, OUT_CHANNELS).permute(0, 4, 1, 2, 3).squeeze(0) # [C, F, H, W] + + tensors = { + "model/latent_in": latent_flat, + "model/latent_unflat": latent_unflat, + "model/sigma": sigma, + "model/timesteps_per_token": timesteps, + "model/context_in": context, + "model/positions": positions, + "model/patchify_out": patchified, + "model/adaln_modulation": tm_mod, + "model/adaln_embedded_timestep": tm_embedded, + "model/velocity_out": vx, + "model/velocity_out_unflat": vx_unflat, + } + for k, v in per_block_outputs.items(): + tensors[f"model/{k}"] = v + + # Use the sd.cpp convention: DiT weights live under "model.diffusion_model.". + # Pairs with LTXRunner's default prefix so the C++ loader reads names verbatim. 
+ sd = {f"model.diffusion_model.{k}": v.detach().float() for k, v in model.state_dict().items()} + + # --- Single-step Euler parity ---------------------------------------------------------------- + # Starting from the noisy latent + the same velocity we just computed, run ONE deterministic + # Euler step using the LTX2Scheduler with 10 steps at σ=0.5 (which falls between sigmas[k] + # and sigmas[k+1] for some k — we pick the step endpoints manually so C++ gets exact inputs). + # This validates the (σ_next - σ) * v formula through the denoiser↔DiT integration boundary. + sched = LTX2Scheduler() + sched_sigmas = sched.execute(steps=10, latent=torch.zeros(1, 1, T), stretch=True, terminal=0.1) + + # Pick one adjacent sigma pair. sigmas[4] is reasonably mid-trajectory for 10 steps. + step_idx = 4 + sigma_cur = sched_sigmas[step_idx].item() + sigma_next = sched_sigmas[step_idx + 1].item() + + # The model was just run at σ=0.5; for the Euler test, re-run at σ_cur (a schedule value). + # The `vx` we already have is at σ=0.5 which doesn't match; redo the forward with sigma_cur. + timesteps_step = torch.tensor([sigma_cur], dtype=torch.float32).view(BATCH, 1).expand(BATCH, T).contiguous() + modality_step = Modality( + latent=latent_flat, + sigma=torch.tensor([sigma_cur], dtype=torch.float32), + timesteps=timesteps_step, + positions=positions, + context=context, + enabled=True, + context_mask=None, + attention_mask=None, + ) + with torch.no_grad(): + v_step, _ = model(video=modality_step, audio=None, perturbations=BatchedPerturbationConfig.empty(BATCH)) + + # Euler step: x_next = x + (σ_next - σ) * v (LTX-2 predicts velocity directly). + x_next = latent_flat + (sigma_next - sigma_cur) * v_step + + # Also dump the unflattened form for C++ convenience. 
+ x_next_unflat = x_next.reshape(BATCH, F, H, W, IN_CHANNELS).permute(0, 4, 1, 2, 3).squeeze(0) # [C, F, H, W] + v_step_unflat = v_step.reshape(BATCH, F, H, W, OUT_CHANNELS).permute(0, 4, 1, 2, 3).squeeze(0) + + tensors["euler/sigma_cur"] = torch.tensor([sigma_cur], dtype=torch.float32) + tensors["euler/sigma_next"] = torch.tensor([sigma_next], dtype=torch.float32) + tensors["euler/v_step"] = v_step + tensors["euler/v_step_unflat"] = v_step_unflat + tensors["euler/x_next"] = x_next + tensors["euler/x_next_unflat"] = x_next_unflat + + return tensors, sd + + +def dump_full_model_v2(num_layers: int = NUM_LAYERS, + zero_scale_shift_table: bool = True, + prefix: str = "model.diffusion_model_v2", + tensor_prefix: str = "v2model", + seed_offset: int = 4): + """Tiny LTXModel (VideoOnly) with V2 features enabled: + - cross_attention_adaln=True (adds prompt_scale_shift_table, prompt_adaln_single, + extends scale_shift_table to 9 coeffs, routes CA through apply_cross_attention_adaln) + - apply_gated_attention=True (adds to_gate_logits on attn1 and attn2) + State-dict is saved under `prefix` so multiple variants can coexist in the same file. + + Args: + num_layers: how many transformer blocks to stack. Deeper values exercise + cross-layer drift (e.g. the real 22B DiT has 48). + zero_scale_shift_table: if False, initialise all scale_shift_table / + prompt_scale_shift_table weights with bounded random values so the + modulation path (AdaLN multiply/shift + CA mod) is actually exercised + — the default True path is too well-conditioned to surface sign/layout + bugs in the (1+scale) and shift-kv branches. 
+ """ + torch.manual_seed(SEED + seed_offset) + + model = LTXModel( + model_type=LTXModelType.VideoOnly, + num_attention_heads=NUM_HEADS, + attention_head_dim=HEAD_DIM, + in_channels=IN_CHANNELS, + out_channels=OUT_CHANNELS, + num_layers=num_layers, + cross_attention_dim=CROSS_ATTN_DIM, + norm_eps=NORM_EPS, + attention_type=AttentionFunction.PYTORCH, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + timestep_scale_multiplier=1000, + use_middle_indices_grid=True, + rope_type=LTXRopeType.SPLIT, + double_precision_rope=False, + apply_gated_attention=True, + caption_projection=None, + cross_attention_adaln=True, + ).eval() + + # Tame weights, same recipe as V1. The sst branch is optional: leaving it zero + # means AdaLN modulation degenerates to identity, which hides bugs in the + # (1 + scale) path and the CA-AdaLN shift_kv / scale_kv broadcast. + with torch.no_grad(): + for name, p in model.named_parameters(): + if "scale_shift_table" in name: + if zero_scale_shift_table: + p.zero_() + else: + # Keep magnitudes small so the stacked modulation doesn't explode + # across layers. scale_shift_table rows are added to a (0, 1]-ish + # AdaLN output; 0.05 keeps the post-modulation scale in ~[0.95, 1.05]. 
+ p.normal_(0.0, 0.05) + continue + if name.endswith("q_norm.weight") or name.endswith("k_norm.weight"): + p.fill_(1.0) + continue + if p.dim() == 1: + p.zero_() + else: + fan_in = p.shape[1] if p.dim() >= 2 else p.numel() + p.normal_(0.0, 1.0 / (fan_in ** 0.5)) + + F, H, W = F_LAT, H_LAT, W_LAT + T = F * H * W + latent = seeded_randn((BATCH, IN_CHANNELS, F, H, W), seed_offset=400) + sigma = torch.tensor([0.5], dtype=torch.float32) + context = seeded_randn((BATCH, CONTEXT_LEN, CROSS_ATTN_DIM), seed_offset=500) + + positions = torch.zeros(BATCH, 3, T, 2, dtype=torch.float32) + idx = 0 + for f in range(F): + for h in range(H): + for w in range(W): + positions[0, 0, idx, 0] = f / FPS + positions[0, 0, idx, 1] = (f + 1) / FPS + positions[0, 1, idx, 0] = h + positions[0, 1, idx, 1] = h + 1 + positions[0, 2, idx, 0] = w + positions[0, 2, idx, 1] = w + 1 + idx += 1 + + latent_flat = latent.permute(0, 2, 3, 4, 1).reshape(BATCH, T, IN_CHANNELS) + timesteps = sigma.view(BATCH, 1).expand(BATCH, T).contiguous() + modality = Modality( + latent=latent_flat, + sigma=sigma, + timesteps=timesteps, + positions=positions, + context=context, + enabled=True, + context_mask=None, + attention_mask=None, + ) + + per_block_outputs = {} + for i, blk in enumerate(model.transformer_blocks): + orig = blk.forward + + def make_capture(idx, original): + def capture(video=None, audio=None, perturbations=None): + out_video, out_audio = original(video=video, audio=audio, perturbations=perturbations) + per_block_outputs[f"block_{idx:02d}_out"] = out_video.x.detach().float().clone() + return out_video, out_audio + return capture + + blk.forward = make_capture(i, orig) + + with torch.no_grad(): + vx, _ = model(video=modality, audio=None, perturbations=BatchedPerturbationConfig.empty(BATCH)) + + with torch.no_grad(): + patchified = model.patchify_proj(latent_flat) + tm_mod, tm_embedded = model.adaln_single(timesteps.flatten() * 1000.0, hidden_dtype=torch.float32) + tm_mod = tm_mod.view(BATCH, -1, 
tm_mod.shape[-1]) + tm_embedded = tm_embedded.view(BATCH, -1, tm_embedded.shape[-1]) + # V2 extra: prompt_adaln output driven by sigma (× scale_mult = 1000). + p_mod, _ = model.prompt_adaln_single( + (sigma * 1000.0).flatten(), hidden_dtype=torch.float32 + ) + p_mod = p_mod.view(BATCH, -1, p_mod.shape[-1]) + + latent_unflat = latent.squeeze(0) + vx_unflat = vx.reshape(BATCH, F, H, W, OUT_CHANNELS).permute(0, 4, 1, 2, 3).squeeze(0) + + tensors = { + f"{tensor_prefix}/latent_in": latent_flat, + f"{tensor_prefix}/latent_unflat": latent_unflat, + f"{tensor_prefix}/sigma": sigma, + f"{tensor_prefix}/timesteps_per_token": timesteps, + f"{tensor_prefix}/context_in": context, + f"{tensor_prefix}/positions": positions, + f"{tensor_prefix}/patchify_out": patchified, + f"{tensor_prefix}/adaln_modulation": tm_mod, + f"{tensor_prefix}/adaln_embedded_timestep": tm_embedded, + f"{tensor_prefix}/prompt_modulation": p_mod, + f"{tensor_prefix}/velocity_out": vx, + f"{tensor_prefix}/velocity_out_unflat": vx_unflat, + } + for k, v in per_block_outputs.items(): + tensors[f"{tensor_prefix}/{k}"] = v + + sd = {f"{prefix}.{k}": v.detach().float() for k, v in model.state_dict().items()} + return tensors, sd + + +# -------- Main -------- + + +def main(): + OUT_DIR.mkdir(parents=True, exist_ok=True) + TENSOR_DIR.mkdir(parents=True, exist_ok=True) + + torch.use_deterministic_algorithms(False) # some ops (layernorm) aren't deterministic + torch.manual_seed(SEED) + + manifest = Manifest() + state_dict: Dict[str, torch.Tensor] = {} + + print("[1/4] RoPE …") + for name, t in dump_rope().items(): + manifest.add(name, t) + + print("[2/4] LTX2Scheduler …") + for name, t in dump_scheduler().items(): + manifest.add(name, t) + + print("[3/4] AdaLayerNormSingle …") + adaln_tensors, adaln_sd = dump_adaln() + for name, t in adaln_tensors.items(): + manifest.add(name, t) + state_dict.update(adaln_sd) + + print("[4/5] Full LTXModel (tiny, V1) …") + model_tensors, model_sd = dump_full_model() + for name, t 
in model_tensors.items(): + manifest.add(name, t) + state_dict.update(model_sd) + + print("[5/6] Full LTXModel (tiny, V2: cross_attention_adaln + apply_gated_attention) …") + model_v2_tensors, model_v2_sd = dump_full_model_v2() + for name, t in model_v2_tensors.items(): + manifest.add(name, t) + state_dict.update(model_v2_sd) + + # Deep V2: 8 layers + non-zero scale_shift_table so accumulated modulation drift + # surfaces. The original V2 dump is too gentle (only 2 layers, zeroed sst) to + # catch bugs that only matter when modulation is non-trivial. + print("[6/6] Full LTXModel (tiny, V2-deep: 8 layers, non-zero scale_shift_table) …") + v2_deep_tensors, v2_deep_sd = dump_full_model_v2( + num_layers=8, + zero_scale_shift_table=False, + prefix="model.diffusion_model_v2_deep", + tensor_prefix="v2deep", + seed_offset=7, + ) + for name, t in v2_deep_tensors.items(): + manifest.add(name, t) + state_dict.update(v2_deep_sd) + + # Safetensors requires contiguous CPU tensors. + sd_contig = {k: v.contiguous().cpu() for k, v in state_dict.items()} + save_file(sd_contig, str(OUT_DIR / "state_dict.safetensors")) + + manifest_path = OUT_DIR / "manifest.json" + manifest.dump(manifest_path) + + with (OUT_DIR / "tensor_names.txt").open("w") as f: + for name in sorted(state_dict.keys()): + t = state_dict[name] + f.write(f"{name}\t{list(t.shape)}\t{t.dtype}\n") + + print(f"Done. 
Wrote {len(manifest.entries)} tensors under {OUT_DIR}.") + print(f"State dict: {len(state_dict)} keys → {OUT_DIR}/state_dict.safetensors") + print(f"Manifest: {manifest_path}") + print(f"Name inventory: {OUT_DIR}/tensor_names.txt") + + +if __name__ == "__main__": + main() diff --git a/tests/ltx_parity/dump_s2d.py b/tests/ltx_parity/dump_s2d.py new file mode 100644 index 000000000..6a2f395c9 --- /dev/null +++ b/tests/ltx_parity/dump_s2d.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +"""Dump reference space-to-depth / depth-to-space outputs along each of the three +stride axes (W, H, T) as standalone test vectors. + +Each case applies a single-axis stride=2 split (the building block that will be +composed to give full 3D SpaceToDepth). We dump both the input and the expected +output so the C++ side can verify its ggml reshape+permute chain byte-exact. + +Output: /tmp/s2d_ref/tensors/*.bin + manifest.json + config.json. +Usage: /home/ilintar/venv/bin/python dump_s2d.py +""" + +from __future__ import annotations + +import json +import pathlib +from dataclasses import dataclass, field +from typing import Dict, List + +import numpy as np +import torch +from einops import rearrange + +OUT_DIR = pathlib.Path("/tmp/s2d_ref") +TENSOR_DIR = OUT_DIR / "tensors" + +# Distinct primes where possible so any mis-axis mixup shows up immediately. 
+B, C, T, H, W = 1, 3, 4, 6, 8 # after: (T/2, H, W), (T, H/2, W), (T, H, W/2) per case +FACTOR = 2 + + +@dataclass +class Manifest: + entries: List[Dict] = field(default_factory=list) + def add(self, name, t): self.entries.append({"name": name, "shape": list(t.shape), "dtype": "f32"}) + def dump(self, p): p.write_text(json.dumps({"entries": self.entries}, indent=2)) + + +def save(t: torch.Tensor, name: str, mf: Manifest): + arr = t.detach().to(torch.float32).contiguous().cpu().numpy() + arr.tofile(TENSOR_DIR / f"{name}.bin") + mf.add(name, t) + + +def s2d_W(x: torch.Tensor, p3: int) -> torch.Tensor: + # [B, C, T, H, W*p3] -> [B, C*p3, T, H, W] + return rearrange(x, "b c t h (w p3) -> b (c p3) t h w", p3=p3) + + +def s2d_H(x: torch.Tensor, p2: int) -> torch.Tensor: + # [B, C, T, H*p2, W] -> [B, C*p2, T, H, W] + return rearrange(x, "b c t (h p2) w -> b (c p2) t h w", p2=p2) + + +def s2d_T(x: torch.Tensor, p1: int) -> torch.Tensor: + # [B, C, T*p1, H, W] -> [B, C*p1, T, H, W] + return rearrange(x, "b c (t p1) h w -> b (c p1) t h w", p1=p1) + + +def s2d_full(x: torch.Tensor, p1: int, p2: int, p3: int) -> torch.Tensor: + # [B, C, T*p1, H*p2, W*p3] -> [B, C*p1*p2*p3, T, H, W] + return rearrange(x, "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w", + p1=p1, p2=p2, p3=p3) + + +def d2s_W(x: torch.Tensor, p3: int) -> torch.Tensor: + return rearrange(x, "b (c p3) t h w -> b c t h (w p3)", p3=p3) + + +def d2s_H(x: torch.Tensor, p2: int) -> torch.Tensor: + return rearrange(x, "b (c p2) t h w -> b c t (h p2) w", p2=p2) + + +def d2s_T(x: torch.Tensor, p1: int) -> torch.Tensor: + return rearrange(x, "b (c p1) t h w -> b c (t p1) h w", p1=p1) + + +def d2s_full(x: torch.Tensor, p1: int, p2: int, p3: int) -> torch.Tensor: + return rearrange(x, "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", + p1=p1, p2=p2, p3=p3) + + +def main(): + OUT_DIR.mkdir(parents=True, exist_ok=True) + TENSOR_DIR.mkdir(parents=True, exist_ok=True) + mf = Manifest() + torch.manual_seed(0) + + # --- Axis-W 
primitive --- + x_w = torch.randn(B, C, T, H, W * FACTOR) + save(x_w, "input_axisW", mf) + save(s2d_W(x_w, FACTOR), "expected_axisW", mf) + + # --- Axis-H primitive --- + x_h = torch.randn(B, C, T, H * FACTOR, W) + save(x_h, "input_axisH", mf) + save(s2d_H(x_h, FACTOR), "expected_axisH", mf) + + # --- Axis-T primitive --- + x_t = torch.randn(B, C, T * FACTOR, H, W) + save(x_t, "input_axisT", mf) + save(s2d_T(x_t, FACTOR), "expected_axisT", mf) + + # --- Full 3D (stride=(2,2,2)) composition --- + x_all = torch.randn(B, C, T * FACTOR, H * FACTOR, W * FACTOR) + save(x_all, "input_full222", mf) + save(s2d_full(x_all, FACTOR, FACTOR, FACTOR), "expected_full222", mf) + + # --- Stride=(1,2,2) (what compress_space_res uses) --- + x_122 = torch.randn(B, C, T, H * FACTOR, W * FACTOR) + save(x_122, "input_full122", mf) + save(s2d_full(x_122, 1, FACTOR, FACTOR), "expected_full122", mf) + + # --- Stride=(2,1,1) (compress_time_res) --- + x_211 = torch.randn(B, C, T * FACTOR, H, W) + save(x_211, "input_full211", mf) + save(s2d_full(x_211, FACTOR, 1, 1), "expected_full211", mf) + + # --- DepthToSpace (single-axis + composed) --- + # Input for axis primitives: [B, C_large, T, H, W] where C_large = C * factor. 
+ dx_w = torch.randn(B, C * FACTOR, T, H, W) + save(dx_w, "dinput_axisW", mf) + save(d2s_W(dx_w, FACTOR), "dexpected_axisW", mf) + + dx_h = torch.randn(B, C * FACTOR, T, H, W) + save(dx_h, "dinput_axisH", mf) + save(d2s_H(dx_h, FACTOR), "dexpected_axisH", mf) + + dx_t = torch.randn(B, C * FACTOR, T, H, W) + save(dx_t, "dinput_axisT", mf) + save(d2s_T(dx_t, FACTOR), "dexpected_axisT", mf) + + dx_222 = torch.randn(B, C * (FACTOR ** 3), T, H, W) + save(dx_222, "dinput_full222", mf) + save(d2s_full(dx_222, FACTOR, FACTOR, FACTOR), "dexpected_full222", mf) + + dx_122 = torch.randn(B, C * (FACTOR ** 2), T, H, W) + save(dx_122, "dinput_full122", mf) + save(d2s_full(dx_122, 1, FACTOR, FACTOR), "dexpected_full122", mf) + + dx_211 = torch.randn(B, C * FACTOR, T, H, W) + save(dx_211, "dinput_full211", mf) + save(d2s_full(dx_211, FACTOR, 1, 1), "dexpected_full211", mf) + + # --- PixelNorm (dim=1 RMS) --- + eps = 1e-8 + pn_in = torch.randn(B, 5, T, H, W) # C=5 to exercise a non-power-of-2 channel + pn_out = pn_in / torch.sqrt((pn_in ** 2).mean(dim=1, keepdim=True) + eps) + save(pn_in, "pn_input", mf) + save(pn_out, "pn_expected", mf) + + # --- PerChannelStatistics --- + # Random mu and sigma (sigma > 0). Buffers shape [C] as in the real VAE. 
+ c_pcs = 6 + pcs_in = torch.randn(B, c_pcs, T, H, W) + pcs_mu = torch.randn(c_pcs) + pcs_sigma = torch.rand(c_pcs) + 0.5 # keep away from zero + save(pcs_in, "pcs_input", mf) + save(pcs_mu, "pcs_mu", mf) + save(pcs_sigma, "pcs_sigma", mf) + save((pcs_in - pcs_mu.view(1, c_pcs, 1, 1, 1)) / pcs_sigma.view(1, c_pcs, 1, 1, 1), + "pcs_normalize_expected", mf) + save((pcs_in * pcs_sigma.view(1, c_pcs, 1, 1, 1)) + pcs_mu.view(1, c_pcs, 1, 1, 1), + "pcs_unnormalize_expected", mf) + + mf.dump(OUT_DIR / "manifest.json") + (OUT_DIR / "config.json").write_text(json.dumps({ + "B": B, "C": C, "T": T, "H": H, "W": W, "FACTOR": FACTOR, + "pn_C": 5, "pn_eps": eps, + "pcs_C": c_pcs, + }, indent=2)) + print(f"wrote {len(mf.entries)} tensors under {OUT_DIR}") + + +if __name__ == "__main__": + main() diff --git a/tests/ltx_parity/dump_vae.py b/tests/ltx_parity/dump_vae.py new file mode 100644 index 000000000..f8f48a9dd --- /dev/null +++ b/tests/ltx_parity/dump_vae.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 +"""Dump tiny LTX-2 VAE reference tensors for C++/GGML parity testing. + +Strategy mirrors dump_reference.py / dump_gemma.py: instantiate a tiny VideoEncoder +and VideoDecoder with deterministic tamed weights, run one forward pass each on +fixed inputs, and save per-block intermediate outputs and the state_dict. + +Tiny config exercises one of each encoder block type (compress_space_res, +compress_time_res, res_x) and a matching decoder (res_x, compress_time, compress_space) +plus the output AdaLN + PerChannelStatistics (un)normalize. 
+ +Usage: + /home/ilintar/venv/bin/python dump_vae.py +""" + +from __future__ import annotations + +import json +import math +import pathlib +from dataclasses import dataclass, field +from typing import Dict, List + +import numpy as np +import torch + +from safetensors.torch import save_file + +from ltx_core.model.video_vae.video_vae import VideoEncoder, VideoDecoder +from ltx_core.model.video_vae.enums import NormLayerType, LogVarianceType, PaddingModeType + +# -------- Config -------- + +SEED = 0 +OUT_DIR = pathlib.Path("/tmp/vae_ref") +TENSOR_DIR = OUT_DIR / "tensors" + +# Tiny VAE config. patch_size=2 (vs standard 4) to keep spatial dims small. +# encoder: compress_space_res(×2 ch) then compress_time_res(×2 ch) then res_x(1 layer). +# decoder: res_x(1 layer), compress_time, compress_space (reversed during construction). +IN_CHANNELS = 3 +LATENT_CHANNELS = 8 +DECODER_BASE_CH = 8 # decoder conv_in goes 128 -> 8 * 8 = 64 (with *8 multiplier) +PATCH_SIZE = 2 +NORM_LAYER = NormLayerType.PIXEL_NORM +LOG_VAR = LogVarianceType.UNIFORM +PADDING_ENC = PaddingModeType.ZEROS +PADDING_DEC = PaddingModeType.REFLECT + +# Video shape: 1 + 8*k frames required by encoder's validator. 1 + 8*1 = 9 → F=9. +# Spatial must divide by (patch_size * 2 * 2) = 8 for one compress_space_res + one compress_time_res. +# H = W = 16 is the minimum that divides 8 cleanly after patchify. 
+BATCH, F_IN, H_IN, W_IN = 1, 9, 16, 16 + +DECODE_TIMESTEP = 0.05 # Gemma/LTX-2 conventional decoder timestep + + +# -------- Utility -------- + +@dataclass +class Manifest: + entries: List[Dict] = field(default_factory=list) + + def add(self, name: str, t: torch.Tensor): + self.entries.append({"name": name, "shape": list(t.shape), "dtype": "f32"}) + + def dump(self, path: pathlib.Path): + path.write_text(json.dumps({"entries": self.entries}, indent=2)) + + +def save_tensor(t: torch.Tensor, name: str, manifest: Manifest): + safe = name.replace("/", "__") + arr = t.detach().to(torch.float32).contiguous().cpu().numpy() + arr.tofile(TENSOR_DIR / f"{safe}.bin") + manifest.add(name, t) + + +def tame_(model: torch.nn.Module): + """Deterministic, finite weights. Reuses the pattern from dump_reference.py. + + - 1D params (biases, norm weights, scale_shift_tables, per_channel_scale*): + zero-initialized. Gemma-style convention where (1+w) is used as the effective + scale works the same for VAE's ResnetBlock3D AdaLN (hidden * (1 + scale) + shift). + - 2D/3D/4D/5D params (linears, convs): Kaiming-ish with std=1/sqrt(fan_in). + - PerChannelStatistics buffers: std-of-means = 1.0, mean-of-means = 0.0 so + normalize/un_normalize become identity + scale. + """ + g = torch.Generator().manual_seed(SEED) + with torch.no_grad(): + for name, p in model.named_parameters(): + if p.dim() <= 1: + # 0-D (scalars like timestep_scale_multiplier) or 1-D (biases, norm weights, + # scale_shift_tables, per_channel_scale*) — zero init. + p.zero_() + # `timestep_scale_multiplier` needs to be exactly the trained value (1000.0) + # to match the denoiser scale, but a zero multiplier just zeroes the timestep + # embedding; the parity path still exercises the rest of the block math. 
+ else: + fan_in = max(1, p.numel() // p.shape[0]) + std = 1.0 / math.sqrt(fan_in) + p.normal_(mean=0.0, std=std, generator=g) + # PerChannelStatistics is a buffer (not a parameter), init separately: + for name, buf in model.named_buffers(): + if "std-of-means" in name: + buf.fill_(1.0) + elif "mean-of-means" in name: + buf.fill_(0.0) + + +def build_encoder() -> VideoEncoder: + return VideoEncoder( + convolution_dimensions=3, + in_channels=IN_CHANNELS, + out_channels=LATENT_CHANNELS, + encoder_blocks=[ + ("compress_space_res", {"multiplier": 2}), + ("compress_time_res", {"multiplier": 2}), + ("res_x", {"num_layers": 1}), + ], + patch_size=PATCH_SIZE, + norm_layer=NORM_LAYER, + latent_log_var=LOG_VAR, + encoder_spatial_padding_mode=PADDING_ENC, + ) + + +def build_decoder() -> VideoDecoder: + # Encoder reduces temporal by 2 and spatial by patch_size*2 = 4. Decoder matches inverse. + return VideoDecoder( + convolution_dimensions=3, + in_channels=LATENT_CHANNELS, + out_channels=IN_CHANNELS, + decoder_blocks=[ + # order is reversed inside VideoDecoder ctor — this is the encoder-side order: + ("compress_space", {"multiplier": 1}), + ("compress_time", {"multiplier": 1}), + ("res_x", {"num_layers": 1}), + ], + patch_size=PATCH_SIZE, + norm_layer=NORM_LAYER, + causal=False, + timestep_conditioning=True, + decoder_spatial_padding_mode=PADDING_DEC, + base_channels=DECODER_BASE_CH, + ) + + +# -------- Main -------- + +def main(): + OUT_DIR.mkdir(parents=True, exist_ok=True) + TENSOR_DIR.mkdir(parents=True, exist_ok=True) + torch.manual_seed(SEED) + + encoder = build_encoder().eval() + decoder = build_decoder().eval() + + tame_(encoder) + tame_(decoder) + + # Input video: (B=1, C=3, F=9, H=16, W=16). 
+ rng = np.random.default_rng(SEED) + video_np = rng.standard_normal((BATCH, IN_CHANNELS, F_IN, H_IN, W_IN), dtype=np.float32) + video = torch.from_numpy(video_np).clone() + + manifest = Manifest() + save_tensor(video, "video_in", manifest) + print(f"input: shape={tuple(video.shape)}") + + # --- Encoder forward --- + with torch.no_grad(): + x = video + # Replicate VideoEncoder.forward manually so we can cache intermediates. + from ltx_core.model.video_vae.ops import patchify + x = patchify(x, patch_size_hw=PATCH_SIZE, patch_size_t=1) + save_tensor(x, "enc_post_patchify", manifest) + + x = encoder.conv_in(x) + save_tensor(x, "enc_post_conv_in", manifest) + + for i, blk in enumerate(encoder.down_blocks): + x = blk(x) + save_tensor(x, f"enc_block_{i}", manifest) + + x = encoder.conv_norm_out(x) + save_tensor(x, "enc_post_norm", manifest) + x = encoder.conv_act(x) + x = encoder.conv_out(x) + save_tensor(x, "enc_post_conv_out", manifest) + + # Replicate UNIFORM latent_log_var path: means = x[:, :-1], logvar = x[:, -1:]. + if LOG_VAR == LogVarianceType.UNIFORM: + means = x[:, :-1, ...] + logvar = x[:, -1:, ...] + # (We save just the means and the final normalized latent; don't need logvar for parity.) + save_tensor(means, "enc_means_preNorm", manifest) + latent = encoder.per_channel_statistics.normalize(means) + save_tensor(latent, "latent", manifest) + else: + raise RuntimeError("only UNIFORM supported in dumper") + + print(f"latent: shape={tuple(latent.shape)} mean={latent.mean().item():.4f} std={latent.std().item():.4f}") + + # --- Decoder forward (deterministic path; no noise, fixed timestep) --- + timestep = torch.full((BATCH,), DECODE_TIMESTEP, dtype=torch.float32) + with torch.no_grad(): + y = decoder.per_channel_statistics.un_normalize(latent) + save_tensor(y, "dec_post_unnorm", manifest) + + # Match the real decoder.forward: self.causal=False is set by the configurator, + # so every conv call uses causal=False. 
Earlier versions of this dumper relied + # on the default causal=True which diverged from actual behavior and masked a + # conv1/conv2 mismatch in the C++ port. + y = decoder.conv_in(y, causal=False) + save_tensor(y, "dec_post_conv_in", manifest) + + # TimestepEmbedder probe: feed the exact `timestep` used below, save the 256-dim + # result so the C++ side can byte-diff its TimestepEmbedder output against Python's. + # Uses the inner time_embedder (embedding_dim=256) from the res_x block. + te_probe = decoder.up_blocks[0].time_embedder( + timestep=timestep.flatten(), hidden_dtype=y.dtype) + save_tensor(te_probe, "te_probe_up0", manifest) + + # Intermediate after each up_block (reversed decoder config). + # Probe INSIDE the first res_x block: dump the pixel_norm(conv_in_output) to verify + # the norm path is byte-exact. This is the Python `hidden_states = self.norm1(x)` + # inside the first ResnetBlock3D of up_blocks[0]. + from ltx_core.model.common.normalization import PixelNorm + probe_block = decoder.up_blocks[0].res_blocks[0] + y_pre = y # still conv_in output here; save a copy. + y_norm1 = probe_block.norm1(y_pre) + save_tensor(y_norm1, "dec_resblock0_post_norm1", manifest) + # Also save post_adaln1 (just the modulation, no silu/conv yet). 
+ ts_embed_block = decoder.up_blocks[0].time_embedder( + timestep=timestep.flatten(), hidden_dtype=y_pre.dtype + ).view(BATCH, -1, 1, 1, 1) + ada_probe = probe_block.scale_shift_table[None, ..., None, None, None] + ts_embed_block.reshape( + BATCH, 4, -1, 1, 1, 1 + ) + sh1, sc1, sh2, sc2 = ada_probe.unbind(dim=1) + y_adaln1 = y_norm1 * (1 + sc1) + sh1 + save_tensor(y_adaln1, "dec_resblock0_post_adaln1", manifest) + y_silu1 = probe_block.non_linearity(y_adaln1) + y_conv1 = probe_block.conv1(y_silu1, causal=False) + save_tensor(y_conv1, "dec_resblock0_post_conv1", manifest) + y_norm2 = probe_block.norm2(y_conv1) + save_tensor(y_norm2, "dec_resblock0_post_norm2", manifest) + + # Build the timestep embedding that UNetMidBlock3D would use internally for + # the res_x block. The scale multiplier is a learned scalar (we zero-inited it). + # The parity comparison only verifies the *forward* math; if the multiplier is + # 0 then the time embedding collapses. That's fine since we verify tracewise. + # Inject a timestep only when calling the block (passed through forward()). + + # Replicate VideoDecoder.forward partial path. + # Important: the decoder's up_blocks list is REVERSED of the config list. + # Our config: [compress_space, compress_time, res_x]. After reversing: + # [res_x, compress_time, compress_space]. So up_blocks[0] is the res_x. + + # Timestep scale is used inside last_time_embedder; but res_x UNetMidBlock3D + # has its own time_embedder. Pass raw timestep; each block handles scaling. + + for i, blk in enumerate(decoder.up_blocks): + # Only res_x (UNetMidBlock3D) accepts timestep; up/down sample blocks don't. + from ltx_core.model.video_vae.resnet import UNetMidBlock3D + if isinstance(blk, UNetMidBlock3D): + y = blk(y, causal=False, timestep=timestep) + else: + y = blk(y, causal=False) + save_tensor(y, f"dec_block_{i}", manifest) + + # Final AdaLN output + conv_norm_out: this is the `last_scale_shift_table` + time_embedder path. 
+ ada = decoder.last_scale_shift_table[None, ..., None, None, None] + decoder.last_time_embedder( + timestep=(timestep * decoder.timestep_scale_multiplier).flatten(), + hidden_dtype=y.dtype, + ).view(BATCH, 2, -1, 1, 1, 1) + shift, scale = ada.unbind(dim=1) + y = decoder.conv_norm_out(y) + save_tensor(y, "dec_post_pixel_norm", manifest) + y = y * (1 + scale) + shift + save_tensor(y, "dec_post_ada", manifest) + y = decoder.conv_act(y) + y = decoder.conv_out(y, causal=False) + save_tensor(y, "dec_post_conv_out", manifest) + + from ltx_core.model.video_vae.ops import unpatchify + y = unpatchify(y, patch_size_hw=PATCH_SIZE, patch_size_t=1) + save_tensor(y, "video_out", manifest) + + print(f"decoded: shape={tuple(y.shape)} mean={y.mean().item():.4f} std={y.std().item():.4f}") + + # --- State dict: concatenate encoder + decoder + per_channel_statistics under "vae." prefix. --- + prefixed = {} + for k, v in encoder.state_dict().items(): + prefixed[f"vae.encoder.{k}"] = v.to(torch.float32).contiguous() + for k, v in decoder.state_dict().items(): + prefixed[f"vae.decoder.{k}"] = v.to(torch.float32).contiguous() + # PerChannelStatistics is registered inside both encoder & decoder AND also dumped under + # a top-level `vae.per_channel_statistics.*` path (matching the real checkpoint convention, + # per VAE_ENCODER_COMFY_KEYS_FILTER). We keep all three copies so encoder/decoder + # blocks can load from either the nested or the canonical path. + pcs = encoder.per_channel_statistics + for bufname, buf in pcs.named_buffers(): + # .clone() to sever storage sharing with the nested copies — safetensors + # refuses to dump multiple keys pointing at the same underlying buffer. 
+ prefixed[f"vae.per_channel_statistics.{bufname}"] = buf.detach().to(torch.float32).clone().contiguous() + + save_file(prefixed, str(OUT_DIR / "state_dict.safetensors")) + (OUT_DIR / "tensor_names.txt").write_text("\n".join(sorted(prefixed.keys())) + "\n") + manifest.dump(OUT_DIR / "manifest.json") + + (OUT_DIR / "config.json").write_text(json.dumps({ + "in_channels": IN_CHANNELS, + "latent_channels": LATENT_CHANNELS, + "decoder_base_ch": DECODER_BASE_CH, + "patch_size": PATCH_SIZE, + "norm_layer": NORM_LAYER.value, + "log_var": LOG_VAR.value, + "batch": BATCH, + "frames": F_IN, + "height": H_IN, + "width": W_IN, + "decode_timestep": DECODE_TIMESTEP, + "encoder_blocks": [ + ["compress_space_res", {"multiplier": 2}], + ["compress_time_res", {"multiplier": 2}], + ["res_x", {"num_layers": 1}], + ], + "decoder_blocks": [ + ["compress_space", {"multiplier": 1}], + ["compress_time", {"multiplier": 1}], + ["res_x", {"num_layers": 1}], + ], + }, indent=2)) + + print(f"\nDone. Wrote {len(manifest.entries)} tensors under {OUT_DIR}.") + print(f"State dict: {len(prefixed)} keys → {OUT_DIR}/state_dict.safetensors") + + +if __name__ == "__main__": + main() diff --git a/tests/ltx_parity/test_attn_chain_parity.cpp b/tests/ltx_parity/test_attn_chain_parity.cpp new file mode 100644 index 000000000..277a758cc --- /dev/null +++ b/tests/ltx_parity/test_attn_chain_parity.cpp @@ -0,0 +1,165 @@ +// Tests whether forcing softmax to run on CPU (and feeding the result into a +// CUDA mul_mat) closes the V·softmax matmul drift in Gemma's attention. +// +// Loads real Gemma layer-0 _attn_kq_masked.bin and _attn_v.bin, then computes: +// ref_cpu = mul_mat(v, softmax_cpu(kq)) on CPU +// pure_cuda = mul_mat(v, softmax_cuda(kq)) on CUDA +// hybrid = mul_mat(v, softmax_cpu(kq) → CUDA) softmax on CPU, mul_mat on CUDA +// +// If `hybrid` matches `ref_cpu` much better than `pure_cuda` does, then forcing +// softmax to CPU is the win we're looking for. 
+ +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#ifdef SD_USE_CUDA +#include "ggml-cuda.h" +#endif +#include "ggml.h" + +static std::vector run_softmax(ggml_backend_t backend, + const std::vector& kq_data, + int K, int N, int B) { + ggml_init_params params = {16 * 1024 * 1024, nullptr, true}; + ggml_context* ctx = ggml_init(params); + auto src = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K, N, B); + auto dst = ggml_soft_max(ctx, src); + ggml_set_output(dst); + ggml_cgraph* gf = ggml_new_graph(ctx); + ggml_build_forward_expand(gf, dst); + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend); + ggml_backend_tensor_set(src, kq_data.data(), 0, ggml_nbytes(src)); + ggml_gallocr_t ga = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_alloc_graph(ga, gf); + ggml_backend_graph_compute(backend, gf); + + std::vector out(ggml_nelements(dst)); + ggml_backend_tensor_get(dst, out.data(), 0, ggml_nbytes(dst)); + ggml_gallocr_free(ga); + ggml_backend_buffer_free(buf); + ggml_free(ctx); + return out; +} + +static std::vector run_mul_mat(ggml_backend_t backend, + const std::vector& src0_data, + const std::vector& src1_data, + int K, int M, int N, int B, int r2, + bool prec_f32) { + ggml_init_params params = {16 * 1024 * 1024, nullptr, true}; + ggml_context* ctx = ggml_init(params); + auto src0 = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K, M, B / r2); + auto src1 = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K, N, B); + auto dst = ggml_mul_mat(ctx, src0, src1); + if (prec_f32) ggml_mul_mat_set_prec(dst, GGML_PREC_F32); + ggml_set_output(dst); + ggml_cgraph* gf = ggml_new_graph(ctx); + ggml_build_forward_expand(gf, dst); + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend); + ggml_backend_tensor_set(src0, src0_data.data(), 0, ggml_nbytes(src0)); + ggml_backend_tensor_set(src1, src1_data.data(), 0, ggml_nbytes(src1)); + ggml_gallocr_t ga = 
ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_alloc_graph(ga, gf); + ggml_backend_graph_compute(backend, gf); + + std::vector out(ggml_nelements(dst)); + ggml_backend_tensor_get(dst, out.data(), 0, ggml_nbytes(dst)); + ggml_gallocr_free(ga); + ggml_backend_buffer_free(buf); + ggml_free(ctx); + return out; +} + +static std::vector load_bin(const std::string& path) { + FILE* fp = std::fopen(path.c_str(), "rb"); + if (!fp) { std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str()); std::exit(1); } + std::fseek(fp, 0, SEEK_END); + long sz = std::ftell(fp); + std::fseek(fp, 0, SEEK_SET); + std::vector out(sz / sizeof(float)); + std::fread(out.data(), 1, sz, fp); + std::fclose(fp); + return out; +} + +struct Stats { + double max_abs = 0; + double mean_abs = 0; + double cpu_mag = 0; +}; + +static Stats diff(const std::vector& a, const std::vector& b) { + Stats s; + double sum_abs = 0, sum_a_abs = 0; + for (size_t i = 0; i < a.size(); ++i) { + double d = std::fabs((double) a[i] - (double) b[i]); + if (d > s.max_abs) s.max_abs = d; + sum_abs += d; + sum_a_abs += std::fabs((double) a[i]); + } + s.mean_abs = sum_abs / a.size(); + s.cpu_mag = sum_a_abs / a.size(); + return s; +} + +int main() { + const std::string dir = std::getenv("SD_GEMMA_TAPS_DIR") ? 
std::getenv("SD_GEMMA_TAPS_DIR") : "/tmp/gemma_taps"; + auto kq_masked = load_bin(dir + "/_attn_kq_masked.bin"); + auto v = load_bin(dir + "/_attn_v.bin"); + + // Shapes: kq=[128, 128, 16] (K=L_k, N=L_q, B=n_head*N), v=[128, 256, 8] (K=L_k, M=d_head, B/r2=n_kv*N) + const int K = 128, N = 128, B = 16; + const int M = 256, B0 = 8; + if ((int) kq_masked.size() != K * N * B) { + std::fprintf(stderr, "kq_masked size mismatch: got %zu floats\n", kq_masked.size()); return 1; + } + if ((int) v.size() != K * M * B0) { + std::fprintf(stderr, "v size mismatch: got %zu floats\n", v.size()); return 1; + } + + ggml_backend_t cpu_bk = ggml_backend_cpu_init(); + int cuda_dev = std::getenv("SD_CUDA_DEVICE") ? std::atoi(std::getenv("SD_CUDA_DEVICE")) : 1; + ggml_backend_t cuda_bk = ggml_backend_cuda_init(cuda_dev); + if (!cuda_bk) { std::fprintf(stderr, "fatal: CUDA init failed\n"); return 1; } + + std::printf("Step 1: softmax on CPU\n"); + auto sm_cpu = run_softmax(cpu_bk, kq_masked, K, N, B); + std::printf("Step 2: softmax on CUDA\n"); + auto sm_cuda = run_softmax(cuda_bk, kq_masked, K, N, B); + { + auto s = diff(sm_cpu, sm_cuda); + std::printf(" softmax CPU vs CUDA: max=%.6e mean=%.6e cpu_mag=%.6e\n", + s.max_abs, s.mean_abs, s.cpu_mag); + } + + std::printf("Step 3: kqv on CPU using softmax_cpu (reference)\n"); + auto kqv_cpu = run_mul_mat(cpu_bk, v, sm_cpu, K, M, N, B, /*r2=*/2, true); + + std::printf("Step 4a: kqv on CUDA using softmax_cuda (current)\n"); + auto kqv_cuda_pure = run_mul_mat(cuda_bk, v, sm_cuda, K, M, N, B, 2, true); + { + auto s = diff(kqv_cpu, kqv_cuda_pure); + std::printf(" kqv CPU vs CUDA(pure): max=%.6e mean=%.6e cpu_mag=%.6e\n", + s.max_abs, s.mean_abs, s.cpu_mag); + } + + std::printf("Step 4b: kqv on CUDA using softmax_cpu (hybrid)\n"); + auto kqv_cuda_hybrid = run_mul_mat(cuda_bk, v, sm_cpu, K, M, N, B, 2, true); + { + auto s = diff(kqv_cpu, kqv_cuda_hybrid); + std::printf(" kqv CPU vs CUDA(hybrid): max=%.6e mean=%.6e cpu_mag=%.6e\n", + s.max_abs, 
s.mean_abs, s.cpu_mag); + } + + ggml_backend_free(cuda_bk); + ggml_backend_free(cpu_bk); + return 0; +} diff --git a/tests/ltx_parity/test_av_block_parity.cpp b/tests/ltx_parity/test_av_block_parity.cpp new file mode 100644 index 000000000..f2e2b86cb --- /dev/null +++ b/tests/ltx_parity/test_av_block_parity.cpp @@ -0,0 +1,430 @@ +// LTX-2 AV transformer block parity test. +// +// Loads weights + inputs + outputs dumped by tests/ltx_parity/dump_av_block.py, +// constructs an LTXTransformerBlock with the same flags (cross_attention_adaln, +// apply_gated_attention, audio dims), runs forward_av, and diffs against the +// python reference outputs. +// +// Tolerances: F32 CPU backend should match python torch.float32 to ~1e-5 abs / +// ~1e-4 rel. RoPE INTERLEAVED is well-tested in the existing LTX parity. The +// gated-attention sigmoid path adds a tiny amount of drift (per-head gate * 2) +// — still within tolerance for one block. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "ltx.hpp" +#include "ltx_rope.hpp" +#include "model.h" +#include "tensor.hpp" + +namespace { + +// Tiny no-deps JSON-ish reader. dump_av_block.py emits a regular JSON object; +// we only need to look up nested fields like manifest["weights"][name]["shape"] +// and ["path"], plus manifest["config"][k]. To avoid pulling in a JSON library +// we shell out to /home/ilintar/venv/bin/python and read the manifest values +// directly via the load helpers below — the dump script's path is fixed. + +const std::string REF_DIR = "/tmp/ltx_av_block_ref"; + +// Read raw f32 bytes from a file into an sd::Tensor with the given ggml ne. 
+sd::Tensor load_raw(const std::string& path, const std::vector<int64_t>& ne) {
+    sd::Tensor t(ne);
+    std::ifstream f(path, std::ios::binary);
+    if (!f.is_open()) {
+        std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str());
+        std::exit(2);
+    }
+    f.read(reinterpret_cast<char*>(t.data()),
+           static_cast<std::streamsize>(t.numel() * sizeof(float)));
+    if (!f.good()) {
+        std::fprintf(stderr, "fatal: short read on %s\n", path.c_str());
+        std::exit(2);
+    }
+    return t;
+}
+
+// Elementwise error statistics between two f32 buffers of equal length n.
+struct DiffStats {
+    double max_abs = 0.0;   // max |a[i] - b[i]|
+    double mean_abs = 0.0;  // mean |a[i] - b[i]|
+    double max_rel = 0.0;   // max |a[i] - b[i]| / (|b[i]| + 1e-8)
+    int64_t n = 0;          // number of elements compared
+};
+
+// Compare buffers a (candidate) and b (reference); all accumulation in double.
+DiffStats diff_f32(const float* a, const float* b, int64_t n) {
+    DiffStats s;
+    s.n = n;
+    double sum = 0.0;
+    for (int64_t i = 0; i < n; ++i) {
+        double abs_err = std::fabs(double(a[i]) - double(b[i]));
+        double rel_err = abs_err / (std::fabs(double(b[i])) + 1e-8);
+        if (abs_err > s.max_abs) s.max_abs = abs_err;
+        if (rel_err > s.max_rel) s.max_rel = rel_err;
+        sum += abs_err;
+    }
+    // max<int64_t> keeps the division well-formed for mixed int/int64_t and
+    // guards n == 0 so mean_abs stays finite on empty input.
+    s.mean_abs = sum / std::max<int64_t>(1, n);
+    return s;
+}
+
+// Build a [inner_dim, T, 2] PE tensor from a (cos, sin) pair where each input is
+// [B=1, T, inner_dim] in raw f32 row-major. The output's ggml ne layout is
+// (inner_dim fastest, T mid, slice index outer): slice 0 = cos bytes, slice 1 = sin.
+sd::Tensor build_pe_packed(const std::string& cos_path,
+                           const std::string& sin_path,
+                           int64_t inner_dim, int64_t T) {
+    sd::Tensor pe({inner_dim, T, 2});
+    auto cos = load_raw(cos_path, {inner_dim, T, 1});
+    auto sin = load_raw(sin_path, {inner_dim, T, 1});
+    std::memcpy(pe.data(), cos.data(), cos.numel() * sizeof(float));
+    std::memcpy(pe.data() + cos.numel(), sin.data(), sin.numel() * sizeof(float));
+    return pe;
+}
+
+// Custom GGMLRunner that owns a single LTXTransformerBlock with audio enabled
+// and runs forward_av once with externally-supplied inputs.
+struct AVBlockParityRunner : public GGMLRunner { + LTX::LTXTransformerBlock block; + + sd::Tensor vx, ax; + sd::Tensor v_ctx, a_ctx; + sd::Tensor v_mod, a_mod; + sd::Tensor v_pmod, a_pmod; // prompt modulation + sd::Tensor v_pe, a_pe; + sd::Tensor v_cpe, a_cpe; + sd::Tensor v_css, a_css; // cross_scale_shift_modulation + sd::Tensor v_cg, a_cg; // cross_gate_modulation + + sd::Tensor result_v, result_a; // captured after compute + + AVBlockParityRunner(ggml_backend_t backend, + int64_t v_dim, int v_h, int v_hd, + int64_t a_dim, int a_h, int a_hd, + int64_t v_ctx, int64_t a_ctx, + bool cross_attention_adaln, bool apply_gated_attention, + LTX::RopeType rope_type, float norm_eps) + : GGMLRunner(backend, /*offload_params_to_cpu=*/false), + block(v_dim, v_h, v_hd, v_ctx, + cross_attention_adaln, apply_gated_attention, norm_eps, + rope_type, + a_dim, a_h, a_hd, a_ctx) { + block.init(params_ctx, /*tensor_storage_map=*/{}, /*prefix=*/""); + } + + std::string get_desc() override { return "AVBlockParityRunner"; } + + ggml_cgraph* build_graph() { + auto gf = new_graph_custom(LTX::LTX_GRAPH_SIZE); + + ggml_tensor* g_vx = make_input(vx); + ggml_tensor* g_ax = make_input(ax); + ggml_tensor* g_vctx = make_input(v_ctx); + ggml_tensor* g_actx = make_input(a_ctx); + ggml_tensor* g_vmod = make_input(v_mod); + ggml_tensor* g_amod = make_input(a_mod); + ggml_tensor* g_vpmod = v_pmod.empty() ? nullptr : make_input(v_pmod); + ggml_tensor* g_apmod = a_pmod.empty() ? 
nullptr : make_input(a_pmod); + ggml_tensor* g_vpe = make_input(v_pe); + ggml_tensor* g_ape = make_input(a_pe); + ggml_tensor* g_vcpe = make_input(v_cpe); + ggml_tensor* g_acpe = make_input(a_cpe); + ggml_tensor* g_vcss = make_input(v_css); + ggml_tensor* g_acss = make_input(a_css); + ggml_tensor* g_vcg = make_input(v_cg); + ggml_tensor* g_acg = make_input(a_cg); + + LTX::LTX2AVModalityArgs vargs; + vargs.x = g_vx; vargs.context = g_vctx; vargs.modulation = g_vmod; + vargs.pe = g_vpe; vargs.cross_pe = g_vcpe; + vargs.prompt_modulation = g_vpmod; + vargs.cross_scale_shift_modulation = g_vcss; + vargs.cross_gate_modulation = g_vcg; + + LTX::LTX2AVModalityArgs aargs; + aargs.x = g_ax; aargs.context = g_actx; aargs.modulation = g_amod; + aargs.pe = g_ape; aargs.cross_pe = g_acpe; + aargs.prompt_modulation = g_apmod; + aargs.cross_scale_shift_modulation = g_acss; + aargs.cross_gate_modulation = g_acg; + + auto rctx = get_context(); + auto outs = block.forward_av(&rctx, vargs, aargs); + + // Force both to live by adding to the graph; we'll fetch them by name. + ggml_set_name(outs.first, "av_video_out"); + ggml_set_name(outs.second, "av_audio_out"); + ggml_build_forward_expand(gf, outs.first); + ggml_build_forward_expand(gf, outs.second); + + // Cache so we can fetch them post-compute via get_cache_tensor_by_name. 
+ cache("av_video_out", outs.first); + cache("av_audio_out", outs.second); + + return gf; + } + + bool compute_and_capture(int n_threads, + int64_t v_t, int64_t v_dim, + int64_t a_t, int64_t a_dim) { + auto gg = [this]() { return build_graph(); }; + compute(gg, n_threads, /*free_compute_buffer_immediately=*/false, + /*no_return=*/true); + + ggml_tensor* tv = get_cache_tensor_by_name("av_video_out"); + ggml_tensor* ta = get_cache_tensor_by_name("av_audio_out"); + if (tv == nullptr || ta == nullptr) { + std::fprintf(stderr, "fatal: missing cached output tensor(s)\n"); + return false; + } + result_v = sd::Tensor({v_dim, v_t, 1}); + result_a = sd::Tensor({a_dim, a_t, 1}); + ggml_backend_tensor_get(tv, result_v.data(), 0, ggml_nbytes(tv)); + ggml_backend_tensor_get(ta, result_a.data(), 0, ggml_nbytes(ta)); + return true; + } + + // Load a python-dumped weight by name into the corresponding block tensor. + // Returns false if the name is unknown. + bool load_weight_into(const std::string& name, const std::string& path) { + std::map all; + block.get_param_tensors(all, /*prefix=*/""); + auto it = all.find(name); + if (it == all.end()) return false; + ggml_tensor* t = it->second; + std::vector buf(ggml_nelements(t)); + std::ifstream f(path, std::ios::binary); + if (!f.is_open()) return false; + f.read(reinterpret_cast(buf.data()), + static_cast(buf.size() * sizeof(float))); + ggml_backend_tensor_set(t, buf.data(), 0, ggml_nbytes(t)); + return true; + } +}; + +// Read manifest.json by shelling to python (avoids a JSON dependency). +struct Manifest { + std::map> weights_shape; + std::map weights_path; + std::map> inputs_shape; + std::map inputs_path; + std::map> outputs_shape; + std::map outputs_path; + std::map config_str; +}; + +// Parse manifest.json via a one-shot python helper that prints simple lines. 
+// Format: SECTION:NAME:SHAPE,COMMA,SEP:PATH +Manifest read_manifest(const std::string& dir) { + std::string cmd = "/home/ilintar/venv/bin/python -c '" + "import json,sys\n" + "m=json.load(open(\"" + dir + "/manifest.json\"))\n" + "for k,v in m[\"config\"].items(): print(\"C:\"+k+\":\"+str(v))\n" + "for sec in [\"weights\",\"inputs\",\"outputs\"]:\n" + " for n,d in m[sec].items():\n" + " sh=\",\".join(str(x) for x in d[\"shape\"])\n" + " print(sec[0].upper()+\":\"+n+\":\"+sh+\":\"+d[\"path\"])\n" + "' 2>/dev/null"; + FILE* p = popen(cmd.c_str(), "r"); + if (!p) { std::fprintf(stderr, "fatal: popen failed\n"); std::exit(2); } + Manifest m; + char line[4096]; + while (std::fgets(line, sizeof(line), p)) { + std::string s(line); + if (!s.empty() && s.back() == '\n') s.pop_back(); + if (s.size() < 3 || s[1] != ':') continue; + char tag = s[0]; + std::string rest = s.substr(2); + size_t c1 = rest.find(':'); + if (c1 == std::string::npos) continue; + std::string name = rest.substr(0, c1); + std::string after = rest.substr(c1 + 1); + if (tag == 'C') { + m.config_str[name] = after; + continue; + } + size_t c2 = after.find(':'); + if (c2 == std::string::npos) continue; + std::string shape_str = after.substr(0, c2); + std::string path = dir + "/" + after.substr(c2 + 1); + std::vector shape; + std::stringstream ss(shape_str); + std::string tok; + while (std::getline(ss, tok, ',')) { + if (!tok.empty()) shape.push_back(std::stoll(tok)); + } + if (tag == 'W') { m.weights_shape[name] = shape; m.weights_path[name] = path; } + else if (tag == 'I') { m.inputs_shape[name] = shape; m.inputs_path[name] = path; } + else if (tag == 'O') { m.outputs_shape[name] = shape; m.outputs_path[name] = path; } + } + pclose(p); + return m; +} + +int parse_int(const std::map& cfg, const std::string& k) { + auto it = cfg.find(k); + if (it == cfg.end()) { std::fprintf(stderr, "fatal: missing config %s\n", k.c_str()); std::exit(2); } + return std::stoi(it->second); +} +bool parse_bool(const std::map& 
cfg, const std::string& k) { + auto it = cfg.find(k); + if (it == cfg.end()) return false; + return it->second == "True" || it->second == "true" || it->second == "1"; +} + +} // namespace + + +int main() { + Manifest m = read_manifest(REF_DIR); + if (m.config_str.empty()) { + std::fprintf(stderr, "fatal: empty manifest. Run dump_av_block.py first.\n"); + return 2; + } + + // ---- Pull dims & flags from config (parsing a few python repr forms) ---- + auto parse_dict = [&](const std::string& key, const char* sub) -> int { + // matches "{'dim': 128, ..." → grab the value after sub. + auto it = m.config_str.find(key); + if (it == m.config_str.end()) { std::fprintf(stderr, "fatal: %s\n", key.c_str()); std::exit(2); } + std::string s = it->second; + std::string needle = std::string("'") + sub + "': "; + size_t p = s.find(needle); + if (p == std::string::npos) { std::fprintf(stderr, "fatal: %s.%s\n", key.c_str(), sub); std::exit(2); } + p += needle.size(); + size_t q = s.find_first_of(",}", p); + return std::stoi(s.substr(p, q - p)); + }; + + const int V_DIM = parse_dict("video", "dim"); + const int V_H = parse_dict("video", "heads"); + const int V_HD = parse_dict("video", "d_head"); + const int V_CTX = parse_dict("video", "ctx_dim"); + const int T_VIDEO = parse_dict("video", "T"); + const int S_VIDEO = parse_dict("video", "S"); + + const int A_DIM = parse_dict("audio", "dim"); + const int A_H = parse_dict("audio", "heads"); + const int A_HD = parse_dict("audio", "d_head"); + const int A_CTX = parse_dict("audio", "ctx_dim"); + const int T_AUDIO = parse_dict("audio", "T"); + const int S_AUDIO = parse_dict("audio", "S"); + + const bool cross_attention_adaln = parse_bool(m.config_str, "cross_attention_adaln"); + const bool apply_gated_attention = parse_bool(m.config_str, "apply_gated_attention"); + const float norm_eps = std::stof(m.config_str.at("norm_eps")); + const std::string rope_str = m.config_str.at("rope_type"); + const LTX::RopeType rope_type = + 
(rope_str.find("split") != std::string::npos) ? LTX::RopeType::SPLIT : LTX::RopeType::INTERLEAVED; + + std::printf("=== LTX-2 AV transformer block parity ===\n"); + std::printf("video: dim=%d heads=%d d_head=%d ctx=%d T=%d S=%d\n", + V_DIM, V_H, V_HD, V_CTX, T_VIDEO, S_VIDEO); + std::printf("audio: dim=%d heads=%d d_head=%d ctx=%d T=%d S=%d\n", + A_DIM, A_H, A_HD, A_CTX, T_AUDIO, S_AUDIO); + std::printf("flags: cross_attention_adaln=%d apply_gated_attention=%d rope=%s norm_eps=%g\n", + cross_attention_adaln, apply_gated_attention, rope_str.c_str(), norm_eps); + + auto backend = ggml_backend_cpu_init(); + AVBlockParityRunner runner(backend, + V_DIM, V_H, V_HD, + A_DIM, A_H, A_HD, + V_CTX, A_CTX, + cross_attention_adaln, apply_gated_attention, + rope_type, norm_eps); + runner.alloc_params_buffer(); + + // ---- Load every weight by name ---- + int loaded = 0, skipped = 0; + for (const auto& kv : m.weights_path) { + if (!runner.load_weight_into(kv.first, kv.second)) { + std::fprintf(stderr, "WARN: no slot for weight '%s' — skipping\n", kv.first.c_str()); + skipped++; + } else { + loaded++; + } + } + std::printf("loaded %d weights (skipped %d)\n", loaded, skipped); + if (skipped > 0) { + std::fprintf(stderr, "FAIL: some weights were unmapped — implementation gap\n"); + ggml_backend_free(backend); + return 1; + } + + // ---- Load all inputs ---- + auto path_of = [&](const char* k) -> std::string { + auto it = m.inputs_path.find(k); + if (it == m.inputs_path.end()) return ""; + return it->second; + }; + auto load_required = [&](const char* k, const std::vector& ne) -> sd::Tensor { + std::string p = path_of(k); + if (p.empty()) { + std::fprintf(stderr, "fatal: missing input %s\n", k); std::exit(2); + } + return load_raw(p, ne); + }; + auto load_optional = [&](const char* k, const std::vector& ne) -> sd::Tensor { + std::string p = path_of(k); + if (p.empty()) return sd::Tensor(); + return load_raw(p, ne); + }; + + const int B = 1; + const int num_main_mod = 
cross_attention_adaln ? 9 : 6; + + runner.vx = load_required("video__x", {V_DIM, T_VIDEO, B}); + runner.ax = load_required("audio__x", {A_DIM, T_AUDIO, B}); + runner.v_ctx = load_required("video__context", {V_CTX, S_VIDEO, B}); + runner.a_ctx = load_required("audio__context", {A_CTX, S_AUDIO, B}); + runner.v_mod = load_required("video__timesteps", {V_DIM, num_main_mod, B}); + runner.a_mod = load_required("audio__timesteps", {A_DIM, num_main_mod, B}); + + if (cross_attention_adaln) { + runner.v_pmod = load_required("video__prompt_timestep", {V_DIM, 2, B}); + runner.a_pmod = load_required("audio__prompt_timestep", {A_DIM, 2, B}); + } + runner.v_pe = build_pe_packed(path_of("video__pe_cos"), path_of("video__pe_sin"), V_DIM, T_VIDEO); + runner.a_pe = build_pe_packed(path_of("audio__pe_cos"), path_of("audio__pe_sin"), A_DIM, T_AUDIO); + // Cross-modal RoPE: inner_dim_cross == audio.heads * audio.d_head == A_DIM (both modalities). + runner.v_cpe = build_pe_packed(path_of("video__cross_pe_cos"), path_of("video__cross_pe_sin"), A_DIM, T_VIDEO); + runner.a_cpe = build_pe_packed(path_of("audio__cross_pe_cos"), path_of("audio__cross_pe_sin"), A_DIM, T_AUDIO); + runner.v_css = load_required("video__cross_scale_shift_timestep", {V_DIM, 4, B}); + runner.a_css = load_required("audio__cross_scale_shift_timestep", {A_DIM, 4, B}); + runner.v_cg = load_required("video__cross_gate_timestep", {V_DIM, 1, B}); + runner.a_cg = load_required("audio__cross_gate_timestep", {A_DIM, 1, B}); + + // ---- Run ---- + if (!runner.compute_and_capture(/*n_threads=*/1, T_VIDEO, V_DIM, T_AUDIO, A_DIM)) { + ggml_backend_free(backend); + return 1; + } + + // ---- Diff vs python outputs ---- + auto v_ref = load_raw(m.outputs_path.at("video__x_out"), {V_DIM, T_VIDEO, B}); + auto a_ref = load_raw(m.outputs_path.at("audio__x_out"), {A_DIM, T_AUDIO, B}); + + auto vs = diff_f32(runner.result_v.data(), v_ref.data(), runner.result_v.numel()); + auto as_= diff_f32(runner.result_a.data(), a_ref.data(), 
runner.result_a.numel()); + + const double tol_abs = 5e-5; // generous for a 1-block forward at fp32 + bool pass = (vs.max_abs < tol_abs) && (as_.max_abs < tol_abs); + + std::printf("\nvideo out: max_abs=%.3e mean_abs=%.3e max_rel=%.3e (n=%lld)\n", + vs.max_abs, vs.mean_abs, vs.max_rel, (long long)vs.n); + std::printf("audio out: max_abs=%.3e mean_abs=%.3e max_rel=%.3e (n=%lld)\n", + as_.max_abs, as_.mean_abs, as_.max_rel, (long long)as_.n); + std::printf("\n%s (tol_abs=%.0e)\n", pass ? "PASS" : "FAIL", tol_abs); + + ggml_backend_free(backend); + return pass ? 0 : 1; +} diff --git a/tests/ltx_parity/test_av_block_smoke.cpp b/tests/ltx_parity/test_av_block_smoke.cpp new file mode 100644 index 000000000..4d0fcb2c9 --- /dev/null +++ b/tests/ltx_parity/test_av_block_smoke.cpp @@ -0,0 +1,206 @@ +// Structural smoke test for LTX-2 audio-video transformer block. +// +// Goal: exercise the new LTXTransformerBlock::forward_av path end-to-end with +// synthetic random weights and inputs, on the CPU backend, and verify that +// (a) all params allocate correctly, (b) the graph builds, (c) compute runs +// without ggml asserts, (d) outputs are finite and shaped as expected. +// +// This is NOT a numerical-parity test — see test_av_block_parity.cpp for that +// (planned: requires dump_av_block.py to capture python reference tensors). +// +// Tiny config: video_dim=64, audio_dim=32, 2 video tokens, 2 audio tokens, +// context length 4. cross_attention_adaln=false, apply_gated_attention=false, +// rope_type=INTERLEAVED. B=1. + +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "ltx.hpp" +#include "model.h" +#include "tensor.hpp" + +namespace { + +// Minimal runner that hosts a single LTXTransformerBlock with audio_dim>0 and +// exposes forward_av via build_graph. Inputs are read from struct fields so +// the test can wire random data per-call. 
+struct AVBlockRunner : public GGMLRunner { + int64_t video_dim, audio_dim; + int v_heads, v_head_dim; + int a_heads, a_head_dim; + int64_t v_ctx_dim, a_ctx_dim; + bool cross_attention_adaln = false; + LTX::LTXTransformerBlock block; + + // Per-compute inputs. + sd::Tensor vx_t, ax_t; // [dim, L, B] + sd::Tensor v_ctx_t, a_ctx_t; // [ctx_dim, L_ctx, B] + sd::Tensor v_mod_t, a_mod_t; // [dim, num_mod, B] + sd::Tensor v_pe_t, a_pe_t; // [inner_dim, L, 2] + sd::Tensor v_cross_pe_t, a_cross_pe_t; + sd::Tensor v_css_t, a_css_t; // [dim, 4, B] + sd::Tensor v_cg_t, a_cg_t; // [dim, 1, B] + + AVBlockRunner(ggml_backend_t backend, + int64_t v_dim, int v_h, int v_hd, + int64_t a_dim, int a_h, int a_hd, + int64_t v_ctx, int64_t a_ctx) + : GGMLRunner(backend, /*offload_params_to_cpu=*/false), + video_dim(v_dim), audio_dim(a_dim), + v_heads(v_h), v_head_dim(v_hd), a_heads(a_h), a_head_dim(a_hd), + v_ctx_dim(v_ctx), a_ctx_dim(a_ctx), + block(v_dim, v_h, v_hd, v_ctx, + /*cross_attention_adaln=*/false, + /*apply_gated_attention=*/false, + /*norm_eps=*/1e-6f, + LTX::RopeType::INTERLEAVED, + a_dim, a_h, a_hd, a_ctx) { + block.init(params_ctx, /*tensor_storage_map=*/{}, /*prefix=*/""); + } + + std::string get_desc() override { return "AVBlockRunner"; } + + ggml_cgraph* build_graph() { + auto gf = new_graph_custom(LTX::LTX_GRAPH_SIZE); + + ggml_tensor* vx = make_input(vx_t); + ggml_tensor* ax = make_input(ax_t); + ggml_tensor* v_ctx = make_input(v_ctx_t); + ggml_tensor* a_ctx = make_input(a_ctx_t); + ggml_tensor* v_mod = make_input(v_mod_t); + ggml_tensor* a_mod = make_input(a_mod_t); + ggml_tensor* v_pe = make_input(v_pe_t); + ggml_tensor* a_pe = make_input(a_pe_t); + ggml_tensor* v_cpe = make_input(v_cross_pe_t); + ggml_tensor* a_cpe = make_input(a_cross_pe_t); + ggml_tensor* v_css = make_input(v_css_t); + ggml_tensor* a_css = make_input(a_css_t); + ggml_tensor* v_cg = make_input(v_cg_t); + ggml_tensor* a_cg = make_input(a_cg_t); + + LTX::LTX2AVModalityArgs vargs; + vargs.x = vx; 
vargs.context = v_ctx; vargs.modulation = v_mod; + vargs.pe = v_pe; vargs.cross_pe = v_cpe; + vargs.cross_scale_shift_modulation = v_css; + vargs.cross_gate_modulation = v_cg; + + LTX::LTX2AVModalityArgs aargs; + aargs.x = ax; aargs.context = a_ctx; aargs.modulation = a_mod; + aargs.pe = a_pe; aargs.cross_pe = a_cpe; + aargs.cross_scale_shift_modulation = a_css; + aargs.cross_gate_modulation = a_cg; + + auto runner_ctx = get_context(); + auto outs = block.forward_av(&runner_ctx, vargs, aargs); + + // Build a combined output by concatenating along the token dim. For the + // smoke test we just return the audio output (the second piece). + ggml_build_forward_expand(gf, outs.first); + ggml_build_forward_expand(gf, outs.second); + // Mark vx_out as the named final result so compute() can pull it. + return gf; + } + + // Returns concatenated stats: video_max, audio_max, video_finite, audio_finite. + bool run(int n_threads) { + auto get_graph = [this]() { return build_graph(); }; + // We don't use compute<>'s return — outputs are inspected via direct + // ggml_backend_tensor_get from the named result tensor. For simplicity + // we just verify compute does not abort. + auto out = compute(get_graph, n_threads, /*free_compute_buffer_immediately=*/true, + /*no_return=*/true); + return true; + } + + // Fill all params with deterministic uniform [-0.05, 0.05] noise. 
+ void randomize_params(uint32_t seed) { + std::map all; + block.get_param_tensors(all, /*prefix=*/""); + std::mt19937 rng(seed); + std::uniform_real_distribution dist(-0.05f, 0.05f); + for (auto& kv : all) { + ggml_tensor* t = kv.second; + std::vector buf(ggml_nelements(t)); + for (auto& v : buf) v = dist(rng); + ggml_backend_tensor_set(t, buf.data(), 0, ggml_nbytes(t)); + } + } +}; + +template +sd::Tensor rand_tensor(const std::vector& shape, Dist&& d, RNG& rng) { + sd::Tensor t(shape); + for (int64_t i = 0; i < t.numel(); ++i) { + t.data()[i] = d(rng); + } + return t; +} + +bool all_finite(const sd::Tensor& t) { + for (int64_t i = 0; i < t.numel(); ++i) { + if (!std::isfinite(t.data()[i])) return false; + } + return true; +} + +} // namespace + +int main() { + // Tiny dims: B=1, video L=2, audio L=2, context_len=4. + constexpr int64_t V_DIM = 64; // video.inner_dim — must be num_heads*head_dim + constexpr int V_H = 4, V_HD = 16; // 4*16=64 ✓ + constexpr int64_t A_DIM = 32; // audio.inner_dim + constexpr int A_H = 2, A_HD = 16; // 2*16=32 ✓ + constexpr int64_t V_CTX = V_DIM; // skip caption_projection — context dim == video.dim + constexpr int64_t A_CTX = A_DIM; + constexpr int64_t L_V = 2, L_A = 2, L_CTX = 4, B = 1; + + auto backend = ggml_backend_cpu_init(); + AVBlockRunner runner(backend, V_DIM, V_H, V_HD, A_DIM, A_H, A_HD, V_CTX, A_CTX); + runner.alloc_params_buffer(); + + runner.randomize_params(0xBEEF); + + std::mt19937 rng(0x42); + std::uniform_real_distribution nrm(-0.5f, 0.5f); + + // num_mod = 6 for cross_attention_adaln=false + runner.vx_t = rand_tensor({V_DIM, L_V, B}, nrm, rng); + runner.ax_t = rand_tensor({A_DIM, L_A, B}, nrm, rng); + runner.v_ctx_t = rand_tensor({V_CTX, L_CTX, B}, nrm, rng); + runner.a_ctx_t = rand_tensor({A_CTX, L_CTX, B}, nrm, rng); + runner.v_mod_t = rand_tensor({V_DIM, 6, B}, nrm, rng); + runner.a_mod_t = rand_tensor({A_DIM, 6, B}, nrm, rng); + runner.v_pe_t = rand_tensor({V_DIM, L_V, 2}, nrm, rng); + runner.a_pe_t = 
rand_tensor({A_DIM, L_A, 2}, nrm, rng); + // Cross-modal attention inner_dim = audio_heads * audio_head_dim (= A_DIM here), + // applied to queries on either side. So both cross PEs are sized A_DIM. + runner.v_cross_pe_t = rand_tensor({A_DIM, L_V, 2}, nrm, rng); + runner.a_cross_pe_t = rand_tensor({A_DIM, L_A, 2}, nrm, rng); + runner.v_css_t = rand_tensor({V_DIM, 4, B}, nrm, rng); + runner.a_css_t = rand_tensor({A_DIM, 4, B}, nrm, rng); + runner.v_cg_t = rand_tensor({V_DIM, 1, B}, nrm, rng); + runner.a_cg_t = rand_tensor({A_DIM, 1, B}, nrm, rng); + + std::printf("=== LTX-2 AV transformer block smoke test ===\n"); + std::printf("video: dim=%lld heads=%d head_dim=%d L=%lld ctx=[%lld,%lld]\n", + (long long)V_DIM, V_H, V_HD, (long long)L_V, + (long long)V_CTX, (long long)L_CTX); + std::printf("audio: dim=%lld heads=%d head_dim=%d L=%lld ctx=[%lld,%lld]\n", + (long long)A_DIM, A_H, A_HD, (long long)L_A, + (long long)A_CTX, (long long)L_CTX); + + if (!runner.run(/*n_threads=*/1)) { + std::fprintf(stderr, "FAIL: run() returned false\n"); + return 1; + } + std::printf("PASS: forward_av compute completed without abort\n"); + + ggml_backend_free(backend); + return 0; +} diff --git a/tests/ltx_parity/test_av_model_parity.cpp b/tests/ltx_parity/test_av_model_parity.cpp new file mode 100644 index 000000000..353fd0223 --- /dev/null +++ b/tests/ltx_parity/test_av_model_parity.cpp @@ -0,0 +1,331 @@ +// LTX-2 AV MODEL parity test (the full LTXModel, not just one block). +// +// Loads weights+inputs+outputs dumped by tests/ltx_parity/dump_av_model.py and +// runs LTXModel::forward_av on the same inputs, then diffs against the python +// reference. 
Exercises: +// - audio_patchify_proj + audio_adaln_single + audio_prompt_adaln_single +// - 4 cross-modal AdaLN modules (av_ca_*_adaln_single) +// - num_layers transformer blocks via forward_av +// - both output heads (video proj_out + audio_proj_out) + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "ltx.hpp" +#include "model.h" +#include "tensor.hpp" + +namespace { + +const std::string REF_DIR = "/tmp/ltx_av_model_ref"; + +sd::Tensor load_raw(const std::string& path, const std::vector& ne) { + sd::Tensor t(ne); + std::ifstream f(path, std::ios::binary); + if (!f.is_open()) { std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str()); std::exit(2); } + f.read(reinterpret_cast(t.data()), + static_cast(t.numel() * sizeof(float))); + if (!f.good()) { std::fprintf(stderr, "fatal: short read on %s\n", path.c_str()); std::exit(2); } + return t; +} + +struct DiffStats { double max_abs=0, mean_abs=0, max_rel=0; int64_t n=0; }; + +DiffStats diff_f32(const float* a, const float* b, int64_t n) { + DiffStats s; s.n = n; double sum = 0.0; + for (int64_t i = 0; i < n; ++i) { + double abs_err = std::fabs(double(a[i]) - double(b[i])); + double rel_err = abs_err / (std::fabs(double(b[i])) + 1e-8); + if (abs_err > s.max_abs) s.max_abs = abs_err; + if (rel_err > s.max_rel) s.max_rel = rel_err; + sum += abs_err; + } + s.mean_abs = sum / std::max(1, n); + return s; +} + +sd::Tensor build_pe(const std::string& cos_path, const std::string& sin_path, + int64_t inner_dim, int64_t T) { + sd::Tensor pe({inner_dim, T, 2}); + auto cos = load_raw(cos_path, {inner_dim, T, 1}); + auto sin = load_raw(sin_path, {inner_dim, T, 1}); + std::memcpy(pe.data(), cos.data(), cos.numel() * sizeof(float)); + std::memcpy(pe.data() + cos.numel(), sin.data(), sin.numel() * sizeof(float)); + return pe; +} + +struct Manifest { + std::map weights_path; + std::map inputs_path; + std::map 
outputs_path; + std::map config_str; +}; + +Manifest read_manifest(const std::string& dir) { + std::string cmd = "/home/ilintar/venv/bin/python -c '" + "import json,sys\n" + "m=json.load(open(\"" + dir + "/manifest.json\"))\n" + "for k,v in m[\"config\"].items(): print(\"C:\"+k+\":\"+str(v))\n" + "for sec in [\"weights\",\"inputs\",\"outputs\"]:\n" + " for n,d in m[sec].items():\n" + " print(sec[0].upper()+\":\"+n+\":\"+d[\"path\"])\n" + "' 2>/dev/null"; + FILE* p = popen(cmd.c_str(), "r"); + if (!p) { std::fprintf(stderr, "fatal: popen failed\n"); std::exit(2); } + Manifest m; + char line[4096]; + while (std::fgets(line, sizeof(line), p)) { + std::string s(line); + if (!s.empty() && s.back() == '\n') s.pop_back(); + if (s.size() < 3 || s[1] != ':') continue; + char tag = s[0]; + std::string rest = s.substr(2); + size_t c = rest.find(':'); + if (c == std::string::npos) continue; + std::string name = rest.substr(0, c); + std::string val = rest.substr(c + 1); + if (tag == 'C') m.config_str[name] = val; + else if (tag == 'W') m.weights_path[name] = dir + "/" + val; + else if (tag == 'I') m.inputs_path[name] = dir + "/" + val; + else if (tag == 'O') m.outputs_path[name] = dir + "/" + val; + } + pclose(p); + return m; +} + +int parse_dict(const std::map& cfg, const std::string& key, const char* sub) { + auto it = cfg.find(key); + if (it == cfg.end()) { std::fprintf(stderr, "fatal: %s\n", key.c_str()); std::exit(2); } + std::string s = it->second; + std::string needle = std::string("'") + sub + "': "; + size_t p = s.find(needle); + if (p == std::string::npos) { std::fprintf(stderr, "fatal: %s.%s\n", key.c_str(), sub); std::exit(2); } + p += needle.size(); + size_t q = s.find_first_of(",}", p); + return std::stoi(s.substr(p, q - p)); +} +bool parse_bool(const std::map& cfg, const std::string& k) { + auto it = cfg.find(k); + if (it == cfg.end()) return false; + return it->second == "True" || it->second == "true" || it->second == "1"; +} + +struct AVModelParityRunner : 
public GGMLRunner { + LTX::LTXModel model; + + sd::Tensor v_latent, a_latent; + sd::Tensor v_context, a_context; + sd::Tensor v_t_self, a_t_self; + sd::Tensor v_t_prompt_self, a_t_prompt_self; + sd::Tensor v_t_cross_ss, a_t_cross_ss; + sd::Tensor v_t_cross_gate, a_t_cross_gate; + sd::Tensor v_pe, a_pe, v_cpe, a_cpe; + + sd::Tensor result_v, result_a; + + AVModelParityRunner(ggml_backend_t backend, LTX::LTXParams params) + : GGMLRunner(backend, /*offload_params_to_cpu=*/false), model(params) { + model.init(params_ctx, /*tensor_storage_map=*/{}, /*prefix=*/""); + } + + std::string get_desc() override { return "AVModelParityRunner"; } + + ggml_cgraph* build_graph() { + auto gf = new_graph_custom(LTX::LTX_GRAPH_SIZE); + ggml_tensor* g_v_latent = make_input(v_latent); + ggml_tensor* g_a_latent = make_input(a_latent); + ggml_tensor* g_v_ctx = make_input(v_context); + ggml_tensor* g_a_ctx = make_input(a_context); + ggml_tensor* g_v_t = make_input(v_t_self); + ggml_tensor* g_a_t = make_input(a_t_self); + ggml_tensor* g_v_t_p = v_t_prompt_self.empty() ? nullptr : make_input(v_t_prompt_self); + ggml_tensor* g_a_t_p = a_t_prompt_self.empty() ? 
nullptr : make_input(a_t_prompt_self); + ggml_tensor* g_v_t_xss = make_input(v_t_cross_ss); + ggml_tensor* g_a_t_xss = make_input(a_t_cross_ss); + ggml_tensor* g_v_t_xg = make_input(v_t_cross_gate); + ggml_tensor* g_a_t_xg = make_input(a_t_cross_gate); + ggml_tensor* g_v_pe = make_input(v_pe); + ggml_tensor* g_a_pe = make_input(a_pe); + ggml_tensor* g_v_cpe = make_input(v_cpe); + ggml_tensor* g_a_cpe = make_input(a_cpe); + + auto rctx = get_context(); + auto outs = model.forward_av(&rctx, + g_v_latent, g_a_latent, + g_v_t, g_a_t, + g_v_t_p, g_a_t_p, + g_v_t_xss, g_a_t_xss, + g_v_t_xg, g_a_t_xg, + g_v_ctx, g_a_ctx, + g_v_pe, g_a_pe, + g_v_cpe, g_a_cpe, + nullptr, nullptr); + ggml_set_name(outs.first, "av_model_video_out"); + ggml_set_name(outs.second, "av_model_audio_out"); + ggml_build_forward_expand(gf, outs.first); + ggml_build_forward_expand(gf, outs.second); + cache("av_model_video_out", outs.first); + cache("av_model_audio_out", outs.second); + return gf; + } + + bool compute_and_capture(int n_threads, + int64_t v_t, int64_t v_out_dim, + int64_t a_t, int64_t a_out_dim) { + auto gg = [this]() { return build_graph(); }; + compute(gg, n_threads, /*free_compute_buffer_immediately=*/false, + /*no_return=*/true); + ggml_tensor* tv = get_cache_tensor_by_name("av_model_video_out"); + ggml_tensor* ta = get_cache_tensor_by_name("av_model_audio_out"); + if (!tv || !ta) { std::fprintf(stderr, "fatal: missing output tensors\n"); return false; } + result_v = sd::Tensor({v_out_dim, v_t, 1}); + result_a = sd::Tensor({a_out_dim, a_t, 1}); + ggml_backend_tensor_get(tv, result_v.data(), 0, ggml_nbytes(tv)); + ggml_backend_tensor_get(ta, result_a.data(), 0, ggml_nbytes(ta)); + return true; + } + + bool load_weight_into(const std::string& name, const std::string& path) { + std::map all; + model.get_param_tensors(all, /*prefix=*/""); + auto it = all.find(name); + if (it == all.end()) return false; + ggml_tensor* t = it->second; + std::vector buf(ggml_nelements(t)); + std::ifstream 
f(path, std::ios::binary); + if (!f.is_open()) return false; + f.read(reinterpret_cast(buf.data()), + static_cast(buf.size() * sizeof(float))); + ggml_backend_tensor_set(t, buf.data(), 0, ggml_nbytes(t)); + return true; + } +}; + +} // namespace + + +int main() { + Manifest m = read_manifest(REF_DIR); + if (m.config_str.empty()) { + std::fprintf(stderr, "fatal: empty manifest. Run dump_av_model.py first.\n"); + return 2; + } + + LTX::LTXParams params; + params.in_channels = parse_dict(m.config_str, "video", "in_channels"); + params.out_channels = parse_dict(m.config_str, "video", "out_channels"); + params.inner_dim = parse_dict(m.config_str, "video", "dim"); + params.num_heads = parse_dict(m.config_str, "video", "heads"); + params.head_dim = parse_dict(m.config_str, "video", "d_head"); + params.cross_attention_dim = parse_dict(m.config_str, "video", "ctx_dim"); + params.num_layers = std::stoi(m.config_str.at("num_layers")); + params.cross_attention_adaln = parse_bool(m.config_str, "cross_attention_adaln"); + params.apply_gated_attention = parse_bool(m.config_str, "apply_gated_attention"); + params.norm_eps = std::stof(m.config_str.at("norm_eps")); + { + const std::string& rs = m.config_str.at("rope_type"); + params.rope_type = (rs.find("split") != std::string::npos) ? 
LTX::RopeType::SPLIT + : LTX::RopeType::INTERLEAVED; + } + + params.has_audio_video = true; + params.audio_in_channels = parse_dict(m.config_str, "audio", "in_channels"); + params.audio_out_channels = parse_dict(m.config_str, "audio", "out_channels"); + params.audio_inner_dim = parse_dict(m.config_str, "audio", "dim"); + params.audio_num_heads = parse_dict(m.config_str, "audio", "heads"); + params.audio_head_dim = parse_dict(m.config_str, "audio", "d_head"); + params.audio_cross_attention_dim = parse_dict(m.config_str, "audio", "ctx_dim"); + + const int T_VIDEO = parse_dict(m.config_str, "video", "T"); + const int T_AUDIO = parse_dict(m.config_str, "audio", "T"); + const int S_VIDEO = parse_dict(m.config_str, "video", "S"); + const int S_AUDIO = parse_dict(m.config_str, "audio", "S"); + + std::printf("=== LTX-2 AV MODEL parity ===\n"); + std::printf("video: dim=%lld heads=%d d_head=%d in=%lld out=%lld T=%d S=%d\n", + (long long)params.inner_dim, params.num_heads, params.head_dim, + (long long)params.in_channels, (long long)params.out_channels, + T_VIDEO, S_VIDEO); + std::printf("audio: dim=%lld heads=%d d_head=%d in=%lld out=%lld T=%d S=%d\n", + (long long)params.audio_inner_dim, params.audio_num_heads, params.audio_head_dim, + (long long)params.audio_in_channels, (long long)params.audio_out_channels, + T_AUDIO, S_AUDIO); + std::printf("flags: ca_adaln=%d gated=%d num_layers=%d\n", + params.cross_attention_adaln, params.apply_gated_attention, params.num_layers); + + auto backend = ggml_backend_cpu_init(); + AVModelParityRunner runner(backend, params); + runner.alloc_params_buffer(); + + int loaded = 0, skipped = 0; + for (const auto& kv : m.weights_path) { + if (!runner.load_weight_into(kv.first, kv.second)) { + std::fprintf(stderr, "WARN: no slot for weight '%s' — skipping\n", kv.first.c_str()); + skipped++; + } else loaded++; + } + std::printf("loaded %d weights (skipped %d)\n", loaded, skipped); + if (skipped > 0) { + std::fprintf(stderr, "FAIL: some weights 
unmapped — implementation gap\n"); + ggml_backend_free(backend); + return 1; + } + + auto path = [&](const char* k) { return m.inputs_path.at(k); }; + runner.v_latent = load_raw(path("video__latent"), {params.in_channels, T_VIDEO, 1}); + runner.a_latent = load_raw(path("audio__latent"), {params.audio_in_channels, T_AUDIO, 1}); + runner.v_context = load_raw(path("video__context"), {params.cross_attention_dim, S_VIDEO, 1}); + runner.a_context = load_raw(path("audio__context"), {params.audio_cross_attention_dim, S_AUDIO, 1}); + runner.v_t_self = load_raw(path("video__t_self"), {1}); + runner.a_t_self = load_raw(path("audio__t_self"), {1}); + if (params.cross_attention_adaln) { + runner.v_t_prompt_self = load_raw(path("video__t_prompt_self"), {1}); + runner.a_t_prompt_self = load_raw(path("audio__t_prompt_self"), {1}); + } + runner.v_t_cross_ss = load_raw(path("video__t_cross_ss"), {1}); + runner.a_t_cross_ss = load_raw(path("audio__t_cross_ss"), {1}); + runner.v_t_cross_gate = load_raw(path("video__t_cross_gate"), {1}); + runner.a_t_cross_gate = load_raw(path("audio__t_cross_gate"), {1}); + + runner.v_pe = build_pe(path("video__pe_cos"), path("video__pe_sin"), params.inner_dim, T_VIDEO); + runner.a_pe = build_pe(path("audio__pe_cos"), path("audio__pe_sin"), params.audio_inner_dim, T_AUDIO); + runner.v_cpe = build_pe(path("video__cross_pe_cos"), path("video__cross_pe_sin"), + params.audio_inner_dim, T_VIDEO); + runner.a_cpe = build_pe(path("audio__cross_pe_cos"), path("audio__cross_pe_sin"), + params.audio_inner_dim, T_AUDIO); + + if (!runner.compute_and_capture(/*n_threads=*/1, + T_VIDEO, params.out_channels, + T_AUDIO, params.audio_out_channels)) { + ggml_backend_free(backend); + return 1; + } + + auto v_ref = load_raw(m.outputs_path.at("video__x_out"), {params.out_channels, T_VIDEO, 1}); + auto a_ref = load_raw(m.outputs_path.at("audio__x_out"), {params.audio_out_channels, T_AUDIO, 1}); + + auto vs = diff_f32(runner.result_v.data(), v_ref.data(), 
runner.result_v.numel()); + auto as_ = diff_f32(runner.result_a.data(), a_ref.data(), runner.result_a.numel()); + + const double tol_abs = 1e-4; // 2-block model accumulates more drift than 1-block + bool pass = (vs.max_abs < tol_abs) && (as_.max_abs < tol_abs); + + std::printf("\nvideo out: max_abs=%.3e mean_abs=%.3e max_rel=%.3e (n=%lld)\n", + vs.max_abs, vs.mean_abs, vs.max_rel, (long long)vs.n); + std::printf("audio out: max_abs=%.3e mean_abs=%.3e max_rel=%.3e (n=%lld)\n", + as_.max_abs, as_.mean_abs, as_.max_rel, (long long)as_.n); + std::printf("\n%s (tol_abs=%.0e)\n", pass ? "PASS" : "FAIL", tol_abs); + + ggml_backend_free(backend); + return pass ? 0 : 1; +} diff --git a/tests/ltx_parity/test_connector_parity.cpp b/tests/ltx_parity/test_connector_parity.cpp new file mode 100644 index 000000000..f364dae5c --- /dev/null +++ b/tests/ltx_parity/test_connector_parity.cpp @@ -0,0 +1,297 @@ +// LTX-2 text connector parity test (V1 / 19B). +// +// Loads /tmp/connector_ref/{state_dict.safetensors, tensors/*.bin} produced by +// dump_connector.py, runs: +// 1. CPU feature_extractor_normalize on the stacked input +// 2. LTX2ConnectorRunner::compute through each probe stage +// and diffs against the Python reference. 
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "ltx_connector.hpp" +#include "model.h" +#include "tensor.hpp" + +namespace { + +sd::Tensor load_raw_bin(const std::string& path, const std::vector& shape) { + sd::Tensor t(shape); + std::ifstream f(path, std::ios::binary); + if (!f.is_open()) { + std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str()); + std::exit(2); + } + f.read(reinterpret_cast(t.data()), + static_cast(t.numel() * sizeof(float))); + if (!f.good()) { + std::fprintf(stderr, "fatal: short read on %s (expected %ld floats)\n", + path.c_str(), t.numel()); + std::exit(2); + } + return t; +} + +struct DiffStats { + float max_abs = 0.f; + float mean_abs = 0.f; + float max_rel = 0.f; + int64_t max_abs_idx = -1; +}; + +DiffStats diff_fp32(const float* a, const float* b, int64_t n) { + DiffStats s; + double sum_abs = 0.0; + for (int64_t i = 0; i < n; ++i) { + float abs_err = std::fabs(a[i] - b[i]); + float rel_err = abs_err / (std::fabs(b[i]) + 1e-8f); + if (abs_err > s.max_abs) { + s.max_abs = abs_err; + s.max_abs_idx = i; + } + s.max_rel = std::max(s.max_rel, rel_err); + sum_abs += abs_err; + } + s.mean_abs = static_cast(sum_abs / (n > 0 ? n : 1)); + return s; +} + +} // namespace + +int main() { + const std::string ref_dir = "/tmp/connector_ref"; + const std::string state_path = ref_dir + "/state_dict.safetensors"; + + // Tiny config (must match dump_connector.py). 
+ const int64_t B = 1; + const int64_t T = 8; + const int NUM_HEADS = 2; + const int HEAD_DIM = 32; + const int64_t D = NUM_HEADS * HEAD_DIM; // connector inner_dim = 64 + const int64_t L = 5; // stacked layers + const int64_t FLAT_DIM = D * L; // 320 + const int NUM_LAYERS = 2; + const int NUM_REGISTERS = 4; + const int64_t CAPTION_CHANNELS = D; // 64 + const int64_t CAPTION_HIDDEN = 128; + const int64_t CAPTION_OUT = 128; + const float THETA = 10000.0f; + const std::vector MAX_POS = {1}; + + // --- 1. Load state dict. + ModelLoader loader; + if (!loader.init_from_file(state_path)) { + std::fprintf(stderr, "fatal: init_from_file failed for %s\n", state_path.c_str()); + return 1; + } + const auto& tsm = loader.get_tensor_storage_map(); + std::printf("[state_dict] loaded %zu tensors from %s\n", tsm.size(), state_path.c_str()); + + // --- 2. Construct runner. + ggml_backend_t backend = ggml_backend_cpu_init(); + LTXConnector::LTX2ConnectorRunner runner( + backend, /*offload_params_to_cpu=*/false, + FLAT_DIM, NUM_HEADS, HEAD_DIM, NUM_LAYERS, NUM_REGISTERS, + CAPTION_CHANNELS, CAPTION_HIDDEN, CAPTION_OUT, + THETA, MAX_POS, tsm, /*prefix=*/""); + + runner.alloc_params_buffer(); + std::map param_tensors; + runner.get_param_tensors(param_tensors, ""); + std::printf("[load] %zu param tensors…\n", param_tensors.size()); + + // Diagnose any missing tensors. + int missing_shown = 0; + std::set tsm_keys; + for (const auto& kv : tsm) tsm_keys.insert(kv.first); + for (const auto& pt : param_tensors) { + if (tsm_keys.find(pt.first) == tsm_keys.end()) { + if (missing_shown < 5) { + std::printf("[load] missing in file: %s\n", pt.first.c_str()); + missing_shown++; + } + } + } + + if (!loader.load_tensors(param_tensors)) { + std::fprintf(stderr, "fatal: load_tensors failed\n"); + return 1; + } + + // --- 3. Load stacked input (ref layout: [B, T, D, L]). + auto stacked_in = load_raw_bin(ref_dir + "/tensors/stacked_in.bin", {L, D, T, B}); + + // --- 4. 
CPU-side feature extractor normalization. + std::vector seq_lens(B, static_cast(T)); // all-ones mask + sd::Tensor normed({FLAT_DIM, T, B}); + LTXConnector::feature_extractor_normalize( + stacked_in.data(), seq_lens.data(), normed.data(), + static_cast(B), static_cast(T), static_cast(D), static_cast(L), + "left", 1e-6f); + + // --- 5. Run each probe stage and diff. + struct Probe { + int stage; + const char* name; + std::vector shape; // ne order (innermost first) + float tol_max_abs; + float tol_mean_abs; + }; +// Tolerances reflect: (1) fp16 K/V cast in ggml_ext_attention_ext (~1e-3 per + // attention layer), (2) residual fp32 cos/sin divergence between torch and + // libm at the tail of the freq grid (~6e-3 PE max diff → ~1e-3 per q/k + // rotation). Two attention layers → ~2-3e-3 max_abs cap end-to-end. + const Probe probes[] = { + {0, "feat_ext_out", {D, T, B}, 1e-4f, 5e-5f}, + {1, "conn_block_0_out", {D, T, B}, 3e-3f, 5e-4f}, + {2, "conn_block_1_out", {D, T, B}, 4e-3f, 1e-3f}, + {3, "conn_final_out", {D, T, B}, 3e-3f, 5e-4f}, + {4, "caption_proj_out", {CAPTION_OUT, T, B}, 3e-3f, 1e-3f}, + }; + + bool all_pass = true; + std::printf("\n=== LTX-2 Connector parity ===\n"); + std::printf("%-20s %11s %11s %11s %s\n", "tag", "max_abs", "mean_abs", "max_rel", "result"); + + for (const auto& p : probes) { + auto out = runner.compute(/*n_threads=*/1, normed, p.stage); + auto ref = load_raw_bin(ref_dir + "/tensors/" + p.name + ".bin", p.shape); + if (out.numel() != ref.numel()) { + std::fprintf(stderr, "[%s] size mismatch got=%ld want=%ld\n", + p.name, out.numel(), ref.numel()); + return 1; + } + auto s = diff_fp32(out.data(), ref.data(), out.numel()); + bool pass = s.max_abs < p.tol_max_abs && s.mean_abs < p.tol_mean_abs; + std::printf(" %-18s %.3e %.3e %.3e %s\n", + p.name, s.max_abs, s.mean_abs, s.max_rel, pass ? 
"PASS" : "FAIL"); + if (!pass && s.max_abs_idx >= 0) { + int64_t i = s.max_abs_idx; + std::printf(" max-diff @ idx=%ld: got=%+.6f want=%+.6f diff=%+.6f\n", + i, out.data()[i], ref.data()[i], out.data()[i] - ref.data()[i]); + } + all_pass &= pass; + } + + std::printf("\n%s\n", all_pass ? "Connector parity: PASS" : "Connector parity: FAIL"); + + // ---------- Padded variant: T_REAL < NUM_REGISTERS ---------- + // This section exercises the learnable-register concat path in + // LTX2ConnectorRunner::build_graph that the primary run above skips (there + // T=8 > NUM_REGISTERS=4). The reference is dumped by + // `CONNECTOR_VARIANT=padded dump_connector.py` with NUM_REGISTERS=8, + // SEQ_LEN=8 and a left-padded attention_mask making only the last 3 tokens + // real. Python runs the full pipeline (feature_extractor → replace_padded + // → connector); C++ feeds only the 3 real tokens (slide-to-front done in + // the conditioner on the production path) and the runner's concat-with- + // registers path must reconstruct the same 8-token sequence internally. + const std::string padded_dir = "/tmp/connector_ref_padded"; + std::ifstream padded_check(padded_dir + "/state_dict.safetensors"); + if (!padded_check.is_open()) { + std::printf("\n[padded] %s not found — skip. Run " + "`CONNECTOR_VARIANT=padded dump_connector.py` to enable.\n", + padded_dir.c_str()); + return all_pass ? 
0 : 3; + } + padded_check.close(); + + std::printf("\n=== LTX-2 Connector parity (padded: T_real=3 < num_reg=8) ===\n"); + + const int64_t PAD_T_REAL = 3; + const int64_t PAD_T_FULL = 8; + const int NUM_REGISTERS_PAD = 8; + + ModelLoader pad_loader; + if (!pad_loader.init_from_file(padded_dir + "/state_dict.safetensors")) { + std::fprintf(stderr, "fatal: padded init_from_file failed\n"); + return 1; + } + const auto& pad_tsm = pad_loader.get_tensor_storage_map(); + std::printf("[padded state_dict] loaded %zu tensors\n", pad_tsm.size()); + + LTXConnector::LTX2ConnectorRunner pad_runner( + backend, /*offload_params_to_cpu=*/false, + FLAT_DIM, NUM_HEADS, HEAD_DIM, NUM_LAYERS, NUM_REGISTERS_PAD, + CAPTION_CHANNELS, CAPTION_HIDDEN, CAPTION_OUT, + THETA, MAX_POS, pad_tsm, /*prefix=*/""); + pad_runner.alloc_params_buffer(); + + std::map pad_params; + pad_runner.get_param_tensors(pad_params, ""); + if (!pad_loader.load_tensors(pad_params)) { + std::fprintf(stderr, "fatal: padded load_tensors failed\n"); + return 1; + } + + // Load the full padded stacked input (padded positions at the START), then + // slice to only the T_REAL real tokens at the tail — this is what the + // production conditioner passes to the connector runner after sliding the + // real rows to the front. + auto pad_stacked_full = load_raw_bin(padded_dir + "/tensors/stacked_in.bin", + {L, D, PAD_T_FULL, B}); + // Ref layout [B, T, D, L] → ggml ne [L, D, T, B]. Real tokens occupy + // indices [PAD_T_FULL - PAD_T_REAL .. PAD_T_FULL) along axis T (ne[2]). 
+ sd::Tensor pad_stacked_real({L, D, PAD_T_REAL, B}); + for (int64_t b = 0; b < B; ++b) { + for (int64_t t = 0; t < PAD_T_REAL; ++t) { + for (int64_t d = 0; d < D; ++d) { + for (int64_t l = 0; l < L; ++l) { + int64_t src = ((b * PAD_T_FULL + (PAD_T_FULL - PAD_T_REAL + t)) * D + d) * L + l; + int64_t dst = ((b * PAD_T_REAL + t) * D + d) * L + l; + pad_stacked_real.data()[dst] = pad_stacked_full.data()[src]; + } + } + } + } + + // CPU normalize the real-only stacked input (no padding). + std::vector pad_seq_lens(B, static_cast(PAD_T_REAL)); + sd::Tensor pad_normed({FLAT_DIM, PAD_T_REAL, B}); + LTXConnector::feature_extractor_normalize( + pad_stacked_real.data(), pad_seq_lens.data(), pad_normed.data(), + static_cast(B), static_cast(PAD_T_REAL), static_cast(D), static_cast(L), + "left", 1e-6f); + + // Connector should internally concat learnable_registers[T_real:num_reg] + // → output shape at the final stage is [D, num_reg, B]. + bool pad_pass = true; + const Probe pad_probes[] = { + // Feature-extractor output is just the T_REAL real tokens (shape + // [D, T_REAL, B]); Python's feat_ext_out covers T_FULL padded and we + // only check the real-token tail. + {3, "conn_final_out", {D, PAD_T_FULL, B}, 6e-3f, 2e-3f}, + {4, "caption_proj_out", {CAPTION_OUT, PAD_T_FULL, B}, 6e-3f, 2e-3f}, + }; + + for (const auto& p : pad_probes) { + auto out = pad_runner.compute(/*n_threads=*/1, pad_normed, p.stage); + auto ref = load_raw_bin(padded_dir + "/tensors/" + p.name + ".bin", p.shape); + if (out.numel() != ref.numel()) { + std::fprintf(stderr, "[padded %s] size mismatch got=%ld want=%ld\n", + p.name, out.numel(), ref.numel()); + return 1; + } + auto s = diff_fp32(out.data(), ref.data(), out.numel()); + bool pass = s.max_abs < p.tol_max_abs && s.mean_abs < p.tol_mean_abs; + std::printf(" %-18s %.3e %.3e %.3e %s\n", + p.name, s.max_abs, s.mean_abs, s.max_rel, pass ? 
"PASS" : "FAIL"); + if (!pass && s.max_abs_idx >= 0) { + int64_t i = s.max_abs_idx; + std::printf(" max-diff @ idx=%ld: got=%+.6f want=%+.6f diff=%+.6f\n", + i, out.data()[i], ref.data()[i], out.data()[i] - ref.data()[i]); + } + pad_pass &= pass; + } + + std::printf("\n%s\n", pad_pass ? "Connector padded parity: PASS" : "Connector padded parity: FAIL"); + return (all_pass && pad_pass) ? 0 : 3; +} diff --git a/tests/ltx_parity/test_cont_parity.cpp b/tests/ltx_parity/test_cont_parity.cpp new file mode 100644 index 000000000..05528ba5c --- /dev/null +++ b/tests/ltx_parity/test_cont_parity.cpp @@ -0,0 +1,129 @@ +// Standalone parity test: ggml_cont(ggml_permute(...)) on CPU vs CUDA. +// Loads the same byte-identical v_proj output onto both backends, applies the +// exact reshape/permute/cont chain Gemma uses, and diffs the resulting +// contiguous tensor. + +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#ifdef SD_USE_CUDA +#include "ggml-cuda.h" +#endif +#include "ggml.h" + +static std::vector run(ggml_backend_t backend, + const std::vector& src_data, + int K, int H, int T, int N) { + struct ggml_init_params params = {}; + params.mem_size = 16 * 1024 * 1024; + params.mem_buffer = nullptr; + params.no_alloc = true; + ggml_context* ctx = ggml_init(params); + + // src layout matches Gemma: ggml_reshape_4d(v_proj_out, head_dim, num_kv_heads, n_token, N) + // = [K=head_dim, H=num_kv_heads, T=n_token, N=batch] + auto src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, K, H, T, N); + ggml_set_name(src, "src"); + + // Same chain as ggml_ext_attention_ext when v comes in 4D: + // v = ggml_ext_cont(ggml_permute(v, 1, 2, 0, 3)) → [N, n_kv_head, d_head, L_k] + // v = ggml_reshape_3d(v, L_k, d_head, n_kv_head*N) → [N*n_kv_head, d_head, L_k] + auto permuted = ggml_permute(ctx, src, 1, 2, 0, 3); + auto cont = ggml_cont(ctx, permuted); + ggml_set_name(cont, "cont"); + ggml_set_output(cont); + + ggml_cgraph* gf = 
ggml_new_graph(ctx); + ggml_build_forward_expand(gf, cont); + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (buf == nullptr) { + std::fprintf(stderr, "fatal: alloc_ctx_tensors failed\n"); + std::exit(1); + } + ggml_backend_tensor_set(src, src_data.data(), 0, ggml_nbytes(src)); + + ggml_gallocr_t gallocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(gallocr, gf)) { + std::fprintf(stderr, "fatal: alloc_graph failed\n"); + std::exit(1); + } + if (ggml_backend_graph_compute(backend, gf) != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "fatal: compute failed\n"); + std::exit(1); + } + + std::vector out(ggml_nelements(cont)); + ggml_backend_tensor_get(cont, out.data(), 0, ggml_nbytes(cont)); + + ggml_gallocr_free(gallocr); + ggml_backend_buffer_free(buf); + ggml_free(ctx); + return out; +} + +int main(int argc, char** argv) { + int K = 256, H = 8, T = 128, N = 1; + if (argc >= 5) { + K = std::atoi(argv[1]); + H = std::atoi(argv[2]); + T = std::atoi(argv[3]); + N = std::atoi(argv[4]); + } + std::printf("Shape src=[K=%d, H=%d, T=%d, N=%d]\n", K, H, T, N); + + std::mt19937 rng(42); + std::normal_distribution dist(0.0f, 50.0f); // match Gemma v_proj scale + std::vector src(K * H * T * N); + for (auto& x : src) x = dist(rng); + + std::printf("Running CPU forward...\n"); + ggml_backend_t cpu_backend = ggml_backend_cpu_init(); + auto cpu_out = run(cpu_backend, src, K, H, T, N); + ggml_backend_free(cpu_backend); + +#ifdef SD_USE_CUDA + int cuda_device = 1; + if (const char* d = std::getenv("SD_CUDA_DEVICE")) cuda_device = std::atoi(d); + std::printf("Running CUDA forward (device %d)...\n", cuda_device); + ggml_backend_t cuda_backend = ggml_backend_cuda_init(cuda_device); + if (!cuda_backend) { + std::fprintf(stderr, "fatal: CUDA init failed for device %d\n", cuda_device); + return 1; + } + auto cuda_out = run(cuda_backend, src, K, H, T, N); + ggml_backend_free(cuda_backend); +#endif + + if 
(cpu_out.size() != cuda_out.size()) { + std::fprintf(stderr, "fatal: output size mismatch\n"); + return 1; + } + + double max_abs = 0.0, sum_abs = 0.0, sum_cpu_abs = 0.0; + int argmax = 0; + int n_diff = 0; + for (size_t i = 0; i < cpu_out.size(); ++i) { + double diff = std::fabs((double) cpu_out[i] - (double) cuda_out[i]); + if (diff > 0) n_diff++; + if (diff > max_abs) { max_abs = diff; argmax = (int) i; } + sum_abs += diff; + sum_cpu_abs += std::fabs(cpu_out[i]); + } + std::printf("Diff: max=%.6e mean=%.6e cpu_mean_mag=%.6e n_diff=%d/%zu\n", + max_abs, sum_abs / cpu_out.size(), sum_cpu_abs / cpu_out.size(), + n_diff, cpu_out.size()); + std::printf("Argmax [idx %d]: CPU=%+.9e CUDA=%+.9e\n", argmax, cpu_out[argmax], cuda_out[argmax]); + std::printf("First 6 elements:\n"); + for (int i = 0; i < 6 && i < (int) cpu_out.size(); ++i) { + std::printf(" [%2d] CPU=%+.9e CUDA=%+.9e diff=%+.3e\n", + i, cpu_out[i], cuda_out[i], cuda_out[i] - cpu_out[i]); + } + return 0; +} diff --git a/tests/ltx_parity/test_gemma_cpu_vs_cuda.cpp b/tests/ltx_parity/test_gemma_cpu_vs_cuda.cpp new file mode 100644 index 000000000..808d7c8a7 --- /dev/null +++ b/tests/ltx_parity/test_gemma_cpu_vs_cuda.cpp @@ -0,0 +1,352 @@ +// Layer-0 CPU vs CUDA parity for Gemma 3. +// +// Loads the user's real Gemma GGUF twice — once with a CPU backend, once with +// a CUDA backend — runs compute_all_hidden_states on the same tokens with +// g_layer0_taps set, and diffs each intermediate. This lets us pinpoint which +// Gemma layer-0 op first diverges between CPU and CUDA without pulling in the +// DiT (which would push a 32 GB system into swap/OOM). 
+// +// Usage: +// sd-gemma-cpu-vs-cuda [cuda_device] + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#ifdef SD_USE_CUDA +#include "ggml-cuda.h" +#endif + +#include "llm.hpp" +#include "model.h" +#include "tensor.hpp" + +namespace { + +struct DiffStats { + float max_abs = 0.f; + float mean_abs = 0.f; + int64_t argmax = -1; +}; + +DiffStats diff_f32(const float* a, const float* b, int64_t n) { + DiffStats s; + double sum = 0.0; + for (int64_t i = 0; i < n; ++i) { + float d = std::fabs(a[i] - b[i]); + if (d > s.max_abs) { + s.max_abs = d; + s.argmax = i; + } + sum += d; + } + s.mean_abs = static_cast(sum / (n > 0 ? n : 1)); + return s; +} + +std::vector fetch(ggml_tensor* t) { + size_t n = ggml_nbytes(t); + std::vector out(n); + ggml_backend_tensor_get(t, out.data(), 0, n); + return out; +} + +struct TapDump { + std::string name; + std::vector ne; + ggml_type type; + std::vector data; +}; + +std::vector run_and_dump(const std::string& model_path, + ggml_backend_t backend, + const std::vector& tokens) { + ModelLoader loader; + if (!loader.init_from_file(model_path, "text_encoders.llm.")) { + std::fprintf(stderr, "fatal: init_from_file failed: %s\n", model_path.c_str()); + std::exit(1); + } + loader.convert_tensors_name(); + + // SD_GEMMA_FORCE_TYPE=f16|bf16|f32|q8_0|q4_0 — on-load retype so both + // backends take the same matmul path (avoids iq4_xs×q8_K-vs-q8_1 drift). + if (const char* t = std::getenv("SD_GEMMA_FORCE_TYPE")) { + std::string s = t; + ggml_type tgt = GGML_TYPE_F16; + if (s == "f32") tgt = GGML_TYPE_F32; + else if (s == "bf16") tgt = GGML_TYPE_BF16; + else if (s == "q8_0" || s == "q8") tgt = GGML_TYPE_Q8_0; + else if (s == "q4_0") tgt = GGML_TYPE_Q4_0; + std::printf("[retype] forcing weights to %s\n", ggml_type_name(tgt)); + loader.set_wtype_override(tgt); + } + + // Rename text_encoders.llm.* -> text_encoder.* (matches LTX-2 flow). 
+ auto& tsm = loader.get_tensor_storage_map(); + { + const std::string from = "text_encoders.llm."; + const std::string to = "text_encoder."; + String2TensorStorage out; + for (auto& kv : tsm) { + std::string k = kv.first; + if (k.rfind(from, 0) == 0) { + k = to + k.substr(from.size()); + kv.second.name = k; + } + out[k] = std::move(kv.second); + } + tsm.swap(out); + } + // Gemma sandwich-norm renames (mirrored from stable-diffusion.cpp init). + auto rename_suffix = [&](const std::string& old_suffix, const std::string& new_suffix) { + String2TensorStorage out; + for (auto& kv : tsm) { + std::string k = kv.first; + size_t p = k.rfind(old_suffix); + if (p != std::string::npos && p + old_suffix.size() == k.size() && + k.find("text_encoder.model.layers.") != std::string::npos) { + k = k.substr(0, p) + new_suffix; + kv.second.name = k; + } + out[k] = std::move(kv.second); + } + tsm.swap(out); + }; + rename_suffix(".post_attention_layernorm.weight", ".pre_feedforward_layernorm.weight"); + rename_suffix(".post_attention_norm.weight", ".post_attention_layernorm.weight"); + rename_suffix(".post_ffw_norm.weight", ".post_feedforward_layernorm.weight"); + + LLM::LLMRunner runner(LLM::LLMArch::GEMMA3, backend, /*offload=*/false, + tsm, /*prefix=*/"text_encoder", /*enable_vision=*/false); + + runner.alloc_params_buffer(); + std::map param_tensors; + runner.get_param_tensors(param_tensors, "text_encoder"); + if (!loader.load_tensors(param_tensors)) { + std::fprintf(stderr, "fatal: load_tensors failed\n"); + std::exit(1); + } + + // Dump token_embd weight rows to compare storage and sanity. 
+ auto it = param_tensors.find("text_encoder.model.embed_tokens.weight"); + if (it != param_tensors.end()) { + ggml_tensor* w = it->second; + std::printf("[weight] embed_tokens.weight: type=%s ne=[%ld,%ld,%ld,%ld] nbytes=%zu\n", + ggml_type_name(w->type), (long)w->ne[0], (long)w->ne[1], (long)w->ne[2], (long)w->ne[3], ggml_nbytes(w)); + if (w->type == GGML_TYPE_F32) { + int64_t hidden = w->ne[0]; + for (int64_t row_idx : {0, 1, 2, 100, 106, 262207}) { + std::vector row(hidden); + ggml_backend_tensor_get(w, row.data(), (size_t)row_idx * hidden * sizeof(float), hidden * sizeof(float)); + double sum_abs = 0; + for (float v : row) sum_abs += std::fabs(v); + std::printf("[weight] row %6ld first 4: %+.4e %+.4e %+.4e %+.4e mean_abs=%.3e\n", + (long)row_idx, row[0], row[1], row[2], row[3], sum_abs / hidden); + } + } + } + + const int64_t T = static_cast(tokens.size()); + sd::Tensor input_ids({T, 1}); + for (int64_t i = 0; i < T; ++i) input_ids.data()[i] = tokens[i]; + sd::Tensor empty_mask; + + std::vector taps; + ::g_layer0_taps = &taps; + ::g_attn_layer0_taps = &taps; // share so attention internals also get captured + ::g_attn_tap_count = 0; // reset per-graph budget + fprintf(stderr, "[test] taps vec @ %p, g_attn_layer0_taps @ %p, set to %p\n", + (void*)&taps, (void*)&::g_attn_layer0_taps, (void*)::g_attn_layer0_taps); + auto stacked = runner.compute_all_hidden_states(/*n_threads=*/4, input_ids, empty_mask); + fprintf(stderr, "[test] after compute, taps.size()=%zu\n", taps.size()); + + // Collect tap dumps immediately while compute buffer is still alive. + std::vector tap_dumps; + for (auto* t : taps) { + const char* nm = ggml_get_name(t); + std::fprintf(stderr, "[tap-collect] tensor=%p name='%s' buffer=%p type=%s\n", + (void*)t, nm ? 
nm : "(null)", (void*)t->buffer, ggml_type_name(t->type)); + if (!nm || std::strncmp(nm, "DBG:", 4) != 0) continue; + if (!t->buffer) { + std::fprintf(stderr, "[tap] %s: no buffer (allocator aliased)\n", nm); + continue; + } + TapDump td; + td.name = nm + 4; + td.type = t->type; + for (int i = 0; i < 4; ++i) td.ne.push_back(t->ne[i]); + td.data = fetch(t); + tap_dumps.push_back(std::move(td)); + } + ::g_layer0_taps = nullptr; + ::g_attn_layer0_taps = nullptr; + + // Slice each layer out of the stacked tensor. stacked layout (innermost + // first): ne=[N+1, H, T, B]. For layer l: value at (b,t,h,l). + const int64_t L = runner.params.num_layers + 1; + const int64_t H = runner.params.hidden_size; + const int64_t Tdim = T; + const int64_t B = 1; + const int64_t per_layer = H * Tdim * B; + + std::vector dumps; + dumps.reserve(L); + for (int64_t l = 0; l < L; ++l) { + TapDump d; + d.name = (l == 0) ? "stacked_L00" : ("stacked_L" + std::to_string(l)); + d.type = GGML_TYPE_F32; + d.ne = {H, Tdim, B, 1}; + d.data.resize(per_layer * sizeof(float)); + float* out = reinterpret_cast(d.data.data()); + const float* src = stacked.data(); + for (int64_t b = 0; b < B; ++b) { + for (int64_t t = 0; t < Tdim; ++t) { + for (int64_t h = 0; h < H; ++h) { + int64_t idx_stacked = ((b * Tdim + t) * H + h) * L + l; + out[(b * Tdim + t) * H + h] = src[idx_stacked]; + } + } + } + dumps.push_back(std::move(d)); + } + // Append tap dumps so the caller can diff per-op. + for (auto& td : tap_dumps) dumps.push_back(std::move(td)); + return dumps; +} + +} // namespace + +int main(int argc, char** argv) { + if (argc < 2) { + std::fprintf(stderr, "usage: %s [cuda_device]\n", argv[0]); + return 2; + } + const std::string model_path = argv[1]; + int cuda_device = argc >= 3 ? std::atoi(argv[2]) : 0; + + // Default short prompt; SD_GEMMA_TEST_LEN env var pads to a target length + // (defaults to 128 to match LTX-2's runtime padding). This forces the + // matmul dispatch into the MMQ path where batch > 8. 
+ std::vector tokens = {2, 106, 108, 1055, 674, 25148, 110, 107}; + { + int target_len = 128; + if (const char* e = std::getenv("SD_GEMMA_TEST_LEN")) target_len = std::atoi(e); + if (target_len > (int) tokens.size()) { + tokens.resize(target_len, /*pad token id=*/0); + } + } + + std::printf("[run] CPU forward...\n"); + std::fflush(stdout); + ggml_backend_t cpu_backend = ggml_backend_cpu_init(); + auto cpu_dumps = run_and_dump(model_path, cpu_backend, tokens); + ggml_backend_free(cpu_backend); + std::printf("[run] CPU done, %zu taps\n", cpu_dumps.size()); + +#ifdef SD_USE_CUDA + std::printf("[run] CUDA (device %d) forward...\n", cuda_device); + std::fflush(stdout); + ggml_backend_t cuda_backend = ggml_backend_cuda_init(cuda_device); + if (!cuda_backend) { + std::fprintf(stderr, "fatal: CUDA backend init failed for device %d\n", cuda_device); + return 1; + } + auto cuda_dumps = run_and_dump(model_path, cuda_backend, tokens); + ggml_backend_free(cuda_backend); + std::printf("[run] CUDA done, %zu taps\n", cuda_dumps.size()); +#else + std::fprintf(stderr, "fatal: built without SD_USE_CUDA\n"); + return 1; +#endif + + // Optional: dump specific CPU-side taps to disk for the standalone + // mul_mat parity test. SD_GEMMA_DUMP_DIR=/tmp/gemma_taps writes + // .bin for each tap. We also write a small .shape text file. 
+ if (const char* dir = std::getenv("SD_GEMMA_DUMP_DIR")) { + for (const auto& d : cpu_dumps) { + std::string fn = std::string(dir) + "/" + d.name + ".bin"; + FILE* fp = std::fopen(fn.c_str(), "wb"); + if (fp) { + std::fwrite(d.data.data(), 1, d.data.size(), fp); + std::fclose(fp); + } + std::string fn2 = std::string(dir) + "/" + d.name + ".shape"; + FILE* fp2 = std::fopen(fn2.c_str(), "w"); + if (fp2) { + std::fprintf(fp2, "%ld %ld %ld %ld %s\n", + (long) d.ne[0], (long) d.ne[1], (long) d.ne[2], (long) d.ne[3], + ggml_type_name(d.type)); + std::fclose(fp2); + } + } + std::printf("[dump] wrote %zu CPU taps to %s/\n", cpu_dumps.size(), dir); + } + + // Diff by name. + std::map cpu_idx; + for (const auto& d : cpu_dumps) cpu_idx[d.name] = &d; + + std::printf("\n%-22s %-5s %12s %12s %12s %6s\n", + "tap", "type", "max_abs", "mean_abs", "cpu_mean_mag", "shape"); + int fail_count = 0; + for (const auto& c : cuda_dumps) { + std::fprintf(stderr, "[diff] examining tap '%s' type=%s ne=[%ld,%ld,%ld,%ld]\n", + c.name.c_str(), ggml_type_name(c.type), + (long)c.ne[0], (long)c.ne[1], (long)c.ne[2], (long)c.ne[3]); + auto it = cpu_idx.find(c.name); + if (it == cpu_idx.end()) { + std::printf(" %-20s [missing on CPU side]\n", c.name.c_str()); + continue; + } + const TapDump* p = it->second; + if (p->type != c.type || p->ne != c.ne) { + std::printf(" %-20s type/shape mismatch\n", c.name.c_str()); + continue; + } + if (c.type != GGML_TYPE_F32) { + // Cast to F32 for diffing if needed. For simplicity we only handle F32 here. + std::printf(" %-20s type=%s skipped\n", c.name.c_str(), ggml_type_name(c.type)); + continue; + } + int64_t n = int64_t(p->data.size() / sizeof(float)); + auto s = diff_f32( + reinterpret_cast(p->data.data()), + reinterpret_cast(c.data.data()), + n); + double cpu_mag = 0.0; + const float* cp = reinterpret_cast(p->data.data()); + for (int64_t i = 0; i < n; ++i) cpu_mag += std::fabs(cp[i]); + cpu_mag /= (n > 0 ? 
n : 1); + bool fail = (s.max_abs > 1e-3f * (float)cpu_mag + 1e-4f); + std::printf(" %-20s %-5s %12.3e %12.3e %12.3e [%ld,%ld,%ld,%ld] %s\n", + c.name.c_str(), ggml_type_name(c.type), + s.max_abs, s.mean_abs, (double)cpu_mag, + (long)c.ne[0], (long)c.ne[1], (long)c.ne[2], (long)c.ne[3], + fail ? "FAIL" : "ok"); + if (fail) { + fail_count++; + // First-fail detail dump: first 8 values from each side. + if (fail_count == 1) { + const float* cp = reinterpret_cast(p->data.data()); + const float* cu = reinterpret_cast(c.data.data()); + std::printf(" first 8 floats: CPU vs CUDA\n"); + for (int64_t i = 0; i < 8 && i < n; ++i) { + std::printf(" [%ld] %+.6e vs %+.6e (diff %+.3e)\n", + (long)i, cp[i], cu[i], cu[i] - cp[i]); + } + std::printf(" argmax element: CPU=%+.6e CUDA=%+.6e idx=%ld\n", + cp[s.argmax], cu[s.argmax], (long)s.argmax); + } + } + } + std::printf("\n%d taps diverged (max_abs > 1e-3 × mean(|cpu|) + 1e-4).\n", fail_count); + return fail_count == 0 ? 0 : 3; +} diff --git a/tests/ltx_parity/test_gemma_parity.cpp b/tests/ltx_parity/test_gemma_parity.cpp new file mode 100644 index 000000000..c43c2e567 --- /dev/null +++ b/tests/ltx_parity/test_gemma_parity.cpp @@ -0,0 +1,287 @@ +// Gemma 3 C++ parity test. +// +// Loads /tmp/gemma_ref/{state_dict.safetensors, tensors/*.bin} produced by +// tests/ltx_parity/dump_gemma.py, runs one LLMRunner forward pass on the same +// input_ids, and diffs each of the N+1 hidden states (embedding + per-layer + +// post-final-norm last) against the Python reference. 
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "llm.hpp" +#include "model.h" +#include "tensor.hpp" + +namespace { + +sd::Tensor load_raw_bin(const std::string& path, const std::vector& shape) { + sd::Tensor t(shape); + std::ifstream f(path, std::ios::binary); + if (!f.is_open()) { + std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str()); + std::exit(2); + } + f.read(reinterpret_cast(t.data()), + static_cast(t.numel() * sizeof(float))); + if (!f.good()) { + std::fprintf(stderr, "fatal: short read on %s (expected %ld floats)\n", + path.c_str(), t.numel()); + std::exit(2); + } + return t; +} + +struct DiffStats { + float max_abs = 0.f; + float mean_abs = 0.f; + float max_rel = 0.f; + int64_t max_abs_idx = -1; +}; + +DiffStats diff_fp32(const float* a, const float* b, int64_t n) { + DiffStats s; + double sum_abs = 0.0; + for (int64_t i = 0; i < n; ++i) { + float abs_err = std::fabs(a[i] - b[i]); + float rel_err = abs_err / (std::fabs(b[i]) + 1e-8f); + if (abs_err > s.max_abs) { + s.max_abs = abs_err; + s.max_abs_idx = i; + } + s.max_rel = std::max(s.max_rel, rel_err); + sum_abs += abs_err; + } + s.mean_abs = static_cast(sum_abs / (n > 0 ? n : 1)); + return s; +} + +} // namespace + +int main() { + const std::string ref_dir = "/tmp/gemma_ref"; + const std::string state_path = ref_dir + "/state_dict.safetensors"; + + // --- 1. Load state dict. + ModelLoader loader; + if (!loader.init_from_file(state_path)) { + std::fprintf(stderr, "fatal: init_from_file failed for %s\n", state_path.c_str()); + return 1; + } + // Skip convert_tensors_name(): the "text_encoder" prefix is remapped to + // "cond_stage_model.transformer." by the conversion table (see name_conversion.cpp:1112), + // which would break our direct-load parity test. We match key names exactly. 
+ const auto& tsm = loader.get_tensor_storage_map(); + std::printf("[state_dict] loaded %zu tensors from %s\n", tsm.size(), state_path.c_str()); + + // --- 2. Construct LLMRunner with GEMMA3 arch. Hyperparams auto-detect from tensor shapes. + ggml_backend_t backend = ggml_backend_cpu_init(); + LLM::LLMRunner runner(LLM::LLMArch::GEMMA3, backend, /*offload_params_to_cpu=*/false, + tsm, /*prefix=*/"text_encoder", /*enable_vision=*/false); + + const auto& p = runner.params; + std::printf("[config] layers=%ld hidden=%ld heads=%d kv_heads=%d head_dim=%d " + "ff=%ld vocab=%ld sw=%d pattern=%d\n", + p.num_layers, p.hidden_size, p.num_heads, p.num_kv_heads, p.head_dim, + p.intermediate_size, p.vocab_size, p.sliding_window, p.sliding_window_pattern); + + // --- 3. Load params buffer and weights. + runner.alloc_params_buffer(); + std::map param_tensors; + runner.get_param_tensors(param_tensors, "text_encoder"); + std::printf("[load] loading %zu param tensors…\n", param_tensors.size()); + + // Collect tsm keys for diffing. + std::set tsm_keys; + for (const auto& kv : tsm) tsm_keys.insert(kv.first); + + // Dump any param_tensor keys not present in the file (for diagnosing name mismatches). + int missing_shown = 0; + for (const auto& pt : param_tensors) { + if (tsm_keys.find(pt.first) == tsm_keys.end()) { + if (missing_shown < 5) { + std::printf("[load] missing in file: %s\n", pt.first.c_str()); + missing_shown++; + } + } + } + + if (!loader.load_tensors(param_tensors)) { + std::fprintf(stderr, "fatal: load_tensors failed (some names unmatched?)\n"); + return 1; + } + + // --- 4. Load reference inputs. Dumper saved input_ids as f32 for simplicity. 
+ const int64_t B = 1, T = 8, H = p.hidden_size; + auto input_ids_f32 = load_raw_bin(ref_dir + "/tensors/input_ids.bin", {T, B}); + sd::Tensor input_ids({T, B}); + for (int64_t i = 0; i < T; ++i) { + input_ids.data()[i] = (int32_t)input_ids_f32.data()[i]; + } + sd::Tensor empty_mask; + + std::printf("[input] input_ids: "); + for (int i = 0; i < T; i++) std::printf("%d ", input_ids.data()[i]); + std::printf("\n"); + + // Override the window size to match the tiny-config dump. + runner.params.sliding_window = 4; + // The tiny-config Python dump does NOT apply linear rope_scaling (that's only + // wired for the "deep" variant which mirrors the real 12B). Disable scaling + // on the C++ side so we compare apples-to-apples. + runner.params.rope_scaling_factor_global = 1.0f; + + // --- 5a. First test the basic forward path (returns just last_hidden_state after norm). + std::printf("[compute] basic forward (last_hidden_state only)…\n"); + std::fflush(stdout); + auto basic = runner.compute(/*n_threads=*/1, input_ids, empty_mask, {}, {}); + std::printf("[compute] basic forward done, numel=%ld first=%.4f\n", basic.numel(), basic.numel() > 0 ? basic.data()[0] : 0.f); + std::fflush(stdout); + + // --- 5b. Compute all N+1 hidden states. + std::printf("[compute] running forward pass with all-hidden-states path…\n"); + std::fflush(stdout); + auto stacked = runner.compute_all_hidden_states(/*n_threads=*/1, input_ids, empty_mask); + std::printf("[compute] done, stacked shape=[%zu dims] numel=%ld\n", stacked.shape().size(), stacked.numel()); + std::fflush(stdout); + + // stacked has shape (sd::Tensor layout = innermost-first): {N+1, H, T, B}. 
+ const int64_t N_plus_1 = p.num_layers + 1; + if (stacked.numel() != N_plus_1 * H * T * B) { + std::fprintf(stderr, "fatal: stacked numel mismatch got=%ld expected=%ld\n", + stacked.numel(), N_plus_1 * H * T * B); + return 1; + } + std::printf("[output] stacked shape=[%ld,%ld,%ld,%ld] numel=%ld\n", + N_plus_1, H, T, B, stacked.numel()); + + // --- 6. Slice each layer out of the stacked tensor and diff. + // Memory layout: innermost=N+1, so all layers for one (h, t, b) are adjacent. + // For a given layer_idx l: layer_data[b][t][h] = stacked[((b*T + t)*H + h)*(N+1) + l]. + // Ref layer is stored with innermost=H, shape [B, T, H]. So we reconstruct layer l by + // scattering. + // Tolerances reflect the fp16 cast inside ggml_ext_attention_ext (K/V go through + // GGML_TYPE_F16 before the softmax). Reference Python stays in fp32, so ~1e-3 abs + // drift per attention layer is baked in. For 6 stacked layers we budget ~6× that. + // max_rel is skipped — small reference values blow up relative error even when + // absolute agreement is fine. + const float tol_max_abs = 1e-2f; + const float tol_mean_abs = 2e-3f; + + bool all_pass = true; + std::printf("\n=== Gemma hidden-state parity ===\n"); + std::printf("%-18s %11s %11s %11s\n", "tag", "max_abs", "mean_abs", "max_rel"); + + std::vector layer_buf(B * T * H); + for (int l = 0; l < N_plus_1; l++) { + // Gather layer l from the stacked tensor. + const float* src = stacked.data(); + for (int64_t b = 0; b < B; b++) { + for (int64_t t = 0; t < T; t++) { + for (int64_t h = 0; h < H; h++) { + int64_t stacked_idx = ((b * T + t) * H + h) * N_plus_1 + l; + int64_t ref_idx = (b * T + t) * H + h; + layer_buf[ref_idx] = src[stacked_idx]; + } + } + } + std::string tag = (l == 0) ? "hs_embed" : ("hs_" + std::string(l < 10 ? 
"0" : "") + std::to_string(l)); + auto ref = load_raw_bin(ref_dir + "/tensors/" + tag + ".bin", {H, T, B}); + auto s = diff_fp32(layer_buf.data(), ref.data(), (int64_t)layer_buf.size()); + bool pass = s.max_abs < tol_max_abs && s.mean_abs < tol_mean_abs; + std::printf(" %-16s %.3e %.3e %.3e %s\n", + tag.c_str(), s.max_abs, s.mean_abs, s.max_rel, pass ? "PASS" : "FAIL"); + all_pass &= pass; + } + std::printf("\n%s (tol: max_abs<%.1e mean_abs<%.1e)\n", + all_pass ? "Gemma parity: PASS" : "Gemma parity: FAIL", + tol_max_abs, tol_mean_abs); + + // --- Deep variant parity: 24 layers × 512 hidden, seq=32 with 16-wide sliding --- + // Mirrors the real Gemma 3 12B's sliding_window_pattern=6 (so ~every 6th layer does + // global attention) at scaled-down dims. Catches drift patterns that only appear + // across many layers / real hidden-size, without requiring the full 12B download. + bool deep_present = false; + for (const auto& k : tsm_keys) { + if (k.rfind("text_encoder_deep.", 0) == 0) { + deep_present = true; + break; + } + } + if (!deep_present) { + std::printf("\n[deep] no text_encoder_deep.* tensors found (run " + "`GEMMA_PARITY_VARIANT=deep dump_gemma.py` to enable); skipping\n"); + return all_pass ? 
0 : 3; + } + + std::printf("\n=== Gemma deep parity (24L × 512H, sliding=16, seq=32) ===\n"); + LLM::LLMRunner deep_runner(LLM::LLMArch::GEMMA3, backend, /*offload=*/false, + tsm, /*prefix=*/"text_encoder_deep", /*enable_vision=*/false); + const auto& dp = deep_runner.params; + std::printf("[deep config] layers=%ld hidden=%ld heads=%d kv_heads=%d head_dim=%d ff=%ld\n", + dp.num_layers, dp.hidden_size, dp.num_heads, dp.num_kv_heads, dp.head_dim, + dp.intermediate_size); + + deep_runner.alloc_params_buffer(); + std::map deep_params; + deep_runner.get_param_tensors(deep_params, "text_encoder_deep"); + if (!loader.load_tensors(deep_params)) { + std::fprintf(stderr, "fatal: deep load_tensors failed\n"); + return 1; + } + + const int64_t Td = 32; + const int64_t Hd = dp.hidden_size; + auto deep_input_ids_f32 = load_raw_bin(ref_dir + "/tensors/deep_input_ids.bin", {Td, 1}); + sd::Tensor deep_input_ids({Td, 1}); + for (int64_t i = 0; i < Td; ++i) deep_input_ids.data()[i] = (int32_t)deep_input_ids_f32.data()[i]; + sd::Tensor deep_empty_mask; + + // Override sliding window to match the deep variant's config (tiny: 4, deep: 16). 
+ deep_runner.params.sliding_window = 16; + + auto deep_stacked = deep_runner.compute_all_hidden_states(/*n_threads=*/1, + deep_input_ids, + deep_empty_mask); + const int64_t deep_N_plus_1 = dp.num_layers + 1; + GGML_ASSERT(deep_stacked.numel() == deep_N_plus_1 * Hd * Td * 1); + + std::printf("[deep output] stacked shape=[%ld,%ld,%ld,1] numel=%ld\n", + deep_N_plus_1, Hd, Td, deep_stacked.numel()); + + const float deep_tol_max_abs = 5e-2f; // 24 layers → ~4× baseline drift budget + const float deep_tol_mean_abs = 1e-2f; + bool deep_all_pass = true; + std::printf("%-22s %11s %11s %11s\n", "tag", "max_abs", "mean_abs", "max_rel"); + + std::vector deep_layer_buf(Td * Hd); + for (int l = 0; l < deep_N_plus_1; ++l) { + const float* src = deep_stacked.data(); + for (int64_t t = 0; t < Td; ++t) { + for (int64_t h = 0; h < Hd; ++h) { + int64_t stacked_idx = (t * Hd + h) * deep_N_plus_1 + l; + deep_layer_buf[t * Hd + h] = src[stacked_idx]; + } + } + std::string tag = (l == 0) ? "deep_hs_embed" : ("deep_hs_" + std::string(l < 10 ? "0" : "") + std::to_string(l)); + auto ref = load_raw_bin(ref_dir + "/tensors/" + tag + ".bin", {Hd, Td, 1}); + auto s = diff_fp32(deep_layer_buf.data(), ref.data(), (int64_t)deep_layer_buf.size()); + bool pass = s.max_abs < deep_tol_max_abs && s.mean_abs < deep_tol_mean_abs; + std::printf(" %-20s %.3e %.3e %.3e %s\n", + tag.c_str(), s.max_abs, s.mean_abs, s.max_rel, pass ? "PASS" : "FAIL"); + deep_all_pass &= pass; + } + std::printf("\n%s (tol: max_abs<%.1e mean_abs<%.1e)\n", + deep_all_pass ? "Gemma deep parity: PASS" : "Gemma deep parity: FAIL", + deep_tol_max_abs, deep_tol_mean_abs); + + return (all_pass && deep_all_pass) ? 0 : 3; +} diff --git a/tests/ltx_parity/test_gemma_tokenizer.cpp b/tests/ltx_parity/test_gemma_tokenizer.cpp new file mode 100644 index 000000000..6ce61b28f --- /dev/null +++ b/tests/ltx_parity/test_gemma_tokenizer.cpp @@ -0,0 +1,88 @@ +// Tokenizer parity test for Gemma 3. 
+// +// Encodes a handful of fixed strings with our GemmaTokenizer and compares to the +// token IDs produced by transformers' AutoTokenizer (google/gemma-3-12b-it). The +// expected IDs below are hard-coded from a Python reference run — if the HF vocab +// ever changes they must be regenerated. + +#include +#include +#include +#include + +#include "tokenizers/gemma_tokenizer.h" + +namespace { + +struct Case { + std::string input; + std::vector expected_no_bos; +}; + +bool run_case(GemmaTokenizer& tk, const Case& c, int idx) { + std::vector got = tk.encode(c.input); + bool ok = (got == c.expected_no_bos); + std::printf(" [%2d] ", idx); + if (ok) { + std::printf("PASS %zu tokens: ", got.size()); + } else { + std::printf("FAIL got=%zu exp=%zu ", got.size(), c.expected_no_bos.size()); + } + for (size_t i = 0; i < got.size() && i < 8; i++) std::printf("%d ", got[i]); + if (got.size() > 8) std::printf("..."); + std::printf("\n"); + if (!ok) { + std::printf(" input : %s\n", c.input.c_str()); + std::printf(" expected : "); + for (int x : c.expected_no_bos) std::printf("%d ", x); + std::printf("\n got : "); + for (int x : got) std::printf("%d ", x); + std::printf("\n"); + } + return ok; +} + +} // namespace + +int main(int argc, char** argv) { + const char* default_path = + "/home/ilintar/.cache/huggingface/hub/models--google--gemma-3-12b-it/" + "snapshots/96b6f1eccf38110c56df3a15bffe176da04bfd80/tokenizer.json"; + std::string path = (argc > 1) ? argv[1] : default_path; + + GemmaTokenizer tk; + std::printf("[load] %s\n", path.c_str()); + if (!tk.load_from_file(path)) { + std::fprintf(stderr, "fatal: could not load tokenizer\n"); + return 1; + } + std::printf("[load] vocab=%d bos=%d eos=%d pad=%d unk=%d\n", + tk.vocab_size(), tk.BOS_TOKEN_ID, tk.EOS_TOKEN_ID, tk.PAD_TOKEN_ID, tk.UNK_TOKEN_ID); + + // Ground truth from transformers.AutoTokenizer("google/gemma-3-12b-it") with + // add_special_tokens=False. 
+ std::vector cases = { + {"hello", {23391}}, + {"hello world", {23391, 1902}}, + {" a b", {138, 236746, 138, 236763}}, + {"naïve", {1789, 238527, 560}}, + {"你好", {144626}}, + {"→ a", {238183, 496}}, + {"The quick brown fox jumps over the lazy dog.", + {818, 3823, 8864, 37423, 38167, 1024, 506, 31770, 4799, 236761}}, + {"", {}}, + {" ", {236743}}, + {"\n\n\ttabs and\nnewlines", + {108, 255968, 39218, 532, 107, 208697}}, + {"mixed: ABCdef 123 !@# UNK char: \xe2\x80\x8b", + {63258, 236787, 21593, 2063, 236743, 236770, 236778, 236800, + 1717, 236940, 236865, 7866, 236855, 1577, 236787, 36504}}, + }; + + int pass = 0; + for (size_t i = 0; i < cases.size(); i++) { + if (run_case(tk, cases[i], (int)i)) pass++; + } + std::printf("\n%d / %zu cases passed.\n", pass, cases.size()); + return (pass == (int)cases.size()) ? 0 : 3; +} diff --git a/tests/ltx_parity/test_ltx2_vae_roundtrip.cpp b/tests/ltx_parity/test_ltx2_vae_roundtrip.cpp new file mode 100644 index 000000000..d7d9ae5c4 --- /dev/null +++ b/tests/ltx_parity/test_ltx2_vae_roundtrip.cpp @@ -0,0 +1,263 @@ +// LTX-2 VAE encode→decode round-trip sanity check on a real 22B VAE checkpoint. +// +// Purpose: rule out whether the blocky output from the end-to-end LTX-2 pipeline +// is caused by a broken VAE decoder. Constructs a simple synthetic video +// (color gradient ramps), runs the real LTX-2 VAE through encode→decode, and +// reports the reconstruction MSE + dumps the first output frame's values. +// +// If MSE is small (<0.05 for bounded [-1,1] input), the VAE is sound and the +// structural issue must live upstream (DiT / conditioning). If MSE is high, +// the VAE itself is miscomputing and that explains the pipeline output. 
+// +// Usage: sd-ltx2-vae-roundtrip [WIDTH [HEIGHT [FRAMES]]] + +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" + +#include "ltxvae.hpp" +#include "model.h" +#include "tensor.hpp" + +namespace { + +void apply_pcs_duplication(String2TensorStorage& tsm) { + // Mirror stable-diffusion.cpp's LTX-2-specific duplication so the nested + // encoder.per_channel_statistics.* / decoder.per_channel_statistics.* paths + // that our VideoEncoder/Decoder blocks look up exist. + const std::string top_pre = "first_stage_model.per_channel_statistics."; + std::vector> to_copy; + for (const auto& kv : tsm) { + const std::string& k = kv.first; + if (k.rfind(top_pre, 0) == 0) { + to_copy.push_back({k, k.substr(top_pre.size())}); + } + } + size_t copied = 0; + for (auto& pair : to_copy) { + auto src_it = tsm.find(pair.first); + if (src_it == tsm.end()) continue; + for (const char* sub : {"encoder", "decoder"}) { + std::string dst = "first_stage_model." + std::string(sub) + + ".per_channel_statistics." + pair.second; + if (tsm.find(dst) != tsm.end()) continue; + TensorStorage dup = src_it->second; + dup.name = dst; + tsm[dst] = dup; + copied++; + } + } + std::printf("[pcs] duplicated %zu entries\n", copied); +} + +// Builds a 9-frame [W, H, T, 3] synthetic video in [-1, 1]. Produces a +// spatial gradient ramp in R/G/B channels so reconstruction is easy to eyeball. 
+sd::Tensor make_synthetic_video(int W, int H, int T) { + sd::Tensor v({W, H, T, 3}); + for (int t = 0; t < T; ++t) { + float tphase = static_cast(t) / std::max(T - 1, 1); + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + float r = static_cast(w) / std::max(W - 1, 1) * 2.0f - 1.0f; + float g = static_cast(h) / std::max(H - 1, 1) * 2.0f - 1.0f; + float b = tphase * 2.0f - 1.0f; + v.index(w, h, t, 0) = r; + v.index(w, h, t, 1) = g; + v.index(w, h, t, 2) = b; + } + } + } + return v; +} + +struct DiffStats { + float max_abs = 0.f, mean_abs = 0.f, mse = 0.f; +}; +DiffStats diff_stats(const float* a, const float* b, int64_t n) { + DiffStats s; + double sum_abs = 0.0, sum_sq = 0.0; + for (int64_t i = 0; i < n; ++i) { + float d = a[i] - b[i]; + float ad = std::fabs(d); + s.max_abs = std::max(s.max_abs, ad); + sum_abs += ad; + sum_sq += static_cast(d) * d; + } + s.mean_abs = static_cast(sum_abs / std::max(n, 1)); + s.mse = static_cast(sum_sq / std::max(n, 1)); + return s; +} + +} // namespace + +int main(int argc, char** argv) { + sd_set_log_callback( + [](enum sd_log_level_t /*level*/, const char* text, void* /*data*/) { + std::fputs(text, stderr); + }, + nullptr); + + const char* vae_path = (argc >= 2) + ? argv[1] + : "/media/ilintar/D_SSD/models/ltx-2/ltx-2.3-22b-dev_video_vae.safetensors"; + int W = (argc >= 3) ? std::atoi(argv[2]) : 128; + int H = (argc >= 4) ? std::atoi(argv[3]) : 128; + int T = (argc >= 5) ? std::atoi(argv[4]) : 9; + + std::printf("[cfg] vae_path = %s\n", vae_path); + std::printf("[cfg] input video = %dx%d, %d frames\n", W, H, T); + + ModelLoader loader; + // The raw 22B video_vae.safetensors ships tensors as `encoder.*`, `decoder.*`, + // `per_channel_statistics.*` with no top-level prefix. Passing prefix="vae." on + // init adds that so the subsequent convert_tensors_name() remaps `vae.` → + // `first_stage_model.` via name_conversion.cpp. 
+ if (!loader.init_from_file(vae_path, "vae.")) { + std::fprintf(stderr, "fatal: init_from_file failed for %s\n", vae_path); + return 1; + } + loader.convert_tensors_name(); + auto& tsm = loader.get_tensor_storage_map(); + std::printf("[state_dict] loaded %zu tensors\n", tsm.size()); + apply_pcs_duplication(tsm); + + // Allow toggling timestep conditioning from env so we can A/B the prod config + // (the real 22B checkpoint is timestep_conditioned). + bool tcond = std::getenv("VAE_TIMESTEP_COND") != nullptr; + std::printf("[cfg] timestep_cond = %s\n", tcond ? "true" : "false"); + + ggml_backend_t backend = ggml_backend_cpu_init(); + LTXVAE::LTX2VAERunner vae(backend, /*offload=*/false, tsm, + /*prefix=*/"first_stage_model", + VERSION_LTX2, + /*in_ch=*/3, /*latent_ch=*/128, /*patch=*/4, + /*decoder_base_ch=*/128, /*timestep_cond=*/tcond, + LTXVAE::LTX2VAERunner::ltx2_22b_enc_specs(), + LTXVAE::LTX2VAERunner::ltx2_22b_dec_specs()); + vae.alloc_params_buffer(); + std::map vae_params; + vae.get_param_tensors(vae_params, "first_stage_model"); + std::printf("[vae] requesting %zu param tensors\n", vae_params.size()); + if (!loader.load_tensors(vae_params)) { + std::fprintf(stderr, "fatal: vae load_tensors failed (weights unmatched?)\n"); + return 1; + } + + // Build synthetic [W, H, T, 3] video in [-1, 1]. + auto video = make_synthetic_video(W, H, T); + std::printf("[input] shape = [W=%d, H=%d, T=%d, C=3] min=%.3f max=%.3f mean=%.3f\n", + W, H, T, *std::min_element(video.data(), video.data() + video.numel()), + *std::max_element(video.data(), video.data() + video.numel()), + [&]() { + double s = 0; + for (int64_t i = 0; i < video.numel(); ++i) s += video.data()[i]; + return static_cast(s / video.numel()); + }()); + + // --- Encode --- + std::printf("[encode] running…\n"); + auto latent = vae._compute(/*n_threads=*/1, video, /*decode_graph=*/false); + std::printf("[encode] latent shape = ["); + for (size_t i = 0; i < latent.shape().size(); ++i) + std::printf("%s%lld", (i ? 
", " : ""), (long long)latent.shape()[i]); + std::printf("] numel=%lld\n", (long long)latent.numel()); + if (latent.empty()) { + std::fprintf(stderr, "fatal: encode produced empty output\n"); + return 2; + } + + // Latent first 8 values for eyeballing. + std::printf("[encode] first 8 latent values: "); + for (int i = 0; i < 8 && i < latent.numel(); ++i) + std::printf("%+.3f ", latent.data()[i]); + std::printf("\n"); + + // Encoder's output layout is [W_lat, H_lat, T_lat, C_lat]. The decoder's + // expected input is the same layout. + // --- Decode --- + std::printf("[decode] running…\n"); + auto recon = vae._compute(/*n_threads=*/1, latent, /*decode_graph=*/true); + std::printf("[decode] recon shape = ["); + for (size_t i = 0; i < recon.shape().size(); ++i) + std::printf("%s%lld", (i ? ", " : ""), (long long)recon.shape()[i]); + std::printf("] numel=%lld\n", (long long)recon.numel()); + if (recon.empty()) { + std::fprintf(stderr, "fatal: decode produced empty output\n"); + return 3; + } + + if (recon.numel() != video.numel()) { + std::fprintf(stderr, "fatal: recon numel %lld != input numel %lld " + "(enc/dec changed element count)\n", + (long long)recon.numel(), (long long)video.numel()); + return 4; + } + + std::printf("[decode] first 8 recon values: "); + for (int i = 0; i < 8 && i < recon.numel(); ++i) + std::printf("%+.3f ", recon.data()[i]); + std::printf("\n[input ] first 8 input values: "); + for (int i = 0; i < 8 && i < video.numel(); ++i) + std::printf("%+.3f ", video.data()[i]); + std::printf("\n"); + + // Per-channel mean(input) - mean(recon) bias diagnostic. The encoder→decoder + // round-trip should be near-zero biased. If a per-channel bias is visible, the + // VAE is shifting the output range and that explains "everything looks dark". + // sd::Tensor layout is [W, H, T, C] (ggml memory order), C is outermost. 
+ auto sh = video.shape(); // {W, H, T, 3} from make_synthetic_video + int64_t W_in = sh[0], H_in = sh[1], T_in = sh[2], C_in = sh[3]; + auto rsh = recon.shape(); + int64_t W_o = rsh[0], H_o = rsh[1], T_o = rsh[2], C_o = (rsh.size() >= 4 ? rsh[3] : 3); + std::printf("\n=== per-channel bias (recon - input) ===\n"); + std::printf(" ch in_mean recon_mean bias in_std recon_std\n"); + auto chmean = [](const float* d, int64_t W, int64_t H, int64_t T, int64_t C, int64_t c) { + double s = 0; int64_t n = W * H * T; + for (int64_t t = 0; t < T; ++t) + for (int64_t h = 0; h < H; ++h) + for (int64_t w = 0; w < W; ++w) + s += d[((c * T + t) * H + h) * W + w]; + return s / std::max(n, 1); + }; + auto chstd = [](const float* d, int64_t W, int64_t H, int64_t T, int64_t C, int64_t c, double mean) { + double s = 0; int64_t n = W * H * T; + for (int64_t t = 0; t < T; ++t) + for (int64_t h = 0; h < H; ++h) + for (int64_t w = 0; w < W; ++w) { + double v = d[((c * T + t) * H + h) * W + w] - mean; + s += v * v; + } + return std::sqrt(s / std::max(n, 1)); + }; + for (int64_t c = 0; c < std::min(C_in, C_o); ++c) { + double im = chmean(video.data(), W_in, H_in, T_in, C_in, c); + double rm = chmean(recon.data(), W_o, H_o, T_o, C_o, c); + double is = chstd (video.data(), W_in, H_in, T_in, C_in, c, im); + double rs = chstd (recon.data(), W_o, H_o, T_o, C_o, c, rm); + std::printf(" %lld %+.4f %+.4f %+.4f %.4f %.4f\n", + (long long)c, im, rm, rm - im, is, rs); + } + + // Diff. + auto s = diff_stats(recon.data(), video.data(), recon.numel()); + std::printf("\n=== round-trip diff ===\n"); + std::printf(" max_abs = %.3e\n", s.max_abs); + std::printf(" mean_abs = %.3e\n", s.mean_abs); + std::printf(" mse = %.3e\n", s.mse); + + // Loose pass thresholds. LTX-2 VAE is lossy but mean_abs <0.1 for a smooth + // gradient is a reasonable ceiling. Anything much worse means structural + // divergence, not just compression. 
+ const float tol_mse = 0.05f; + bool pass = s.mse < tol_mse; + std::printf("%s (tol: mse < %.1e)\n", + pass ? "VAE round-trip: PASS" : "VAE round-trip: FAIL", + tol_mse); + return pass ? 0 : 5; +} diff --git a/tests/ltx_parity/test_ltx_parity.cpp b/tests/ltx_parity/test_ltx_parity.cpp new file mode 100644 index 000000000..3bfe14920 --- /dev/null +++ b/tests/ltx_parity/test_ltx_parity.cpp @@ -0,0 +1,438 @@ +// LTX-2 C++ parity test. +// +// Loads the state dict + reference intermediate tensors dumped by +// tests/ltx_parity/dump_reference.py, runs one forward pass of LTXRunner on the same inputs, +// and diffs the output against the Python reference. +// +// Tolerances: F32 backend is expected to match to ~1e-4 abs / ~1e-3 rel. Larger drift points to +// a block-level bug — rerun with --intermediate to capture per-block outputs via the cache API. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "denoiser.hpp" +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "ltx.hpp" +#include "model.h" +#include "tensor.hpp" + +namespace { + +sd::Tensor load_raw_bin(const std::string& path, const std::vector& shape) { + sd::Tensor t(shape); + std::ifstream f(path, std::ios::binary); + if (!f.is_open()) { + std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str()); + std::exit(2); + } + f.read(reinterpret_cast(t.data()), + static_cast(t.numel() * sizeof(float))); + if (!f.good()) { + std::fprintf(stderr, "fatal: short read on %s (expected %ld floats)\n", + path.c_str(), t.numel()); + std::exit(2); + } + return t; +} + +struct DiffStats { + float max_abs = 0.f; + float mean_abs = 0.f; + float max_rel = 0.f; + int64_t max_abs_idx = -1; +}; + +DiffStats diff_fp32(const float* a, const float* b, int64_t n) { + DiffStats s; + double sum_abs = 0.0; + for (int64_t i = 0; i < n; ++i) { + float abs_err = std::fabs(a[i] - b[i]); + float rel_err = abs_err / (std::fabs(b[i]) + 1e-8f); + if (abs_err > s.max_abs) { + s.max_abs 
= abs_err; + s.max_abs_idx = i; + } + s.max_rel = std::max(s.max_rel, rel_err); + sum_abs += abs_err; + } + s.mean_abs = static_cast(sum_abs / (n > 0 ? n : 1)); + return s; +} + +void print_shape(const char* label, const std::vector& shape) { + std::printf("%s[", label); + for (size_t i = 0; i < shape.size(); ++i) { + std::printf("%s%ld", i == 0 ? "" : ", ", shape[i]); + } + std::printf("]\n"); +} + +} // namespace + +// Returns true if all schedule cases agree within absolute tolerance. +bool check_schedule(const std::string& ref_dir) { + struct Case { + int tokens; + int steps; + bool stretch; + const char* label; + }; + // Must match the cases in dump_reference.py::dump_scheduler. + const std::vector cases = { + {1024, 10, true, "tokens1024_steps10_stretch1"}, + {1024, 30, true, "tokens1024_steps30_stretch1"}, + {4096, 10, true, "tokens4096_steps10_stretch1"}, + {4096, 40, true, "tokens4096_steps40_stretch1"}, + {2560, 30, true, "tokens2560_steps30_stretch1"}, + {4096, 8, false, "tokens4096_steps8_stretch0"}, + }; + + bool all_pass = true; + std::printf("\n=== LTX2FlowDenoiser::get_sigmas parity ===\n"); + for (const auto& c : cases) { + LTX2FlowDenoiser denoiser; + denoiser.stretch = c.stretch; + auto cpp_sigmas = denoiser.get_sigmas(static_cast(c.steps), c.tokens, + DISCRETE_SCHEDULER, VERSION_LTX2); + auto ref = load_raw_bin(ref_dir + "/tensors/schedule__" + c.label + ".bin", + {static_cast(c.steps + 1)}); + if (static_cast(cpp_sigmas.size()) != ref.numel()) { + std::fprintf(stderr, "[sched %s] size mismatch cpp=%zu ref=%ld\n", + c.label, cpp_sigmas.size(), ref.numel()); + all_pass = false; + continue; + } + auto s = diff_fp32(cpp_sigmas.data(), ref.data(), ref.numel()); + // Schedules are small floats (≤1), 1e-5 abs tolerance is reasonable; mu/exp arithmetic + // differs negligibly between libm and Python math. 
+ const float tol = 5e-5f; + bool pass = s.max_abs < tol; + std::printf(" %-32s max_abs=%.2e mean_abs=%.2e %s\n", + c.label, s.max_abs, s.mean_abs, pass ? "PASS" : "FAIL"); + if (!pass) { + std::printf(" cpp[0..3] = %.6f %.6f %.6f %.6f\n", cpp_sigmas[0], + cpp_sigmas[1], cpp_sigmas[2], cpp_sigmas[3]); + std::printf(" ref[0..3] = %.6f %.6f %.6f %.6f\n", ref.data()[0], + ref.data()[1], ref.data()[2], ref.data()[3]); + all_pass = false; + } + } + return all_pass; +} + +// Runs one Euler step in C++ using LTX2FlowDenoiser's scheduler values + a DiT velocity output, +// then diffs against the Python reference. +bool check_euler_step(const std::string& ref_dir, LTX::LTXRunner& runner, + const sd::Tensor& latent, const sd::Tensor& context) { + auto sigma_cur_ref = load_raw_bin(ref_dir + "/tensors/euler__sigma_cur.bin", {1}); + auto sigma_next_ref = load_raw_bin(ref_dir + "/tensors/euler__sigma_next.bin", {1}); + auto v_ref = load_raw_bin(ref_dir + "/tensors/euler__v_step_unflat.bin", {6, 4, 2, 16}); + auto x_next_ref = load_raw_bin(ref_dir + "/tensors/euler__x_next_unflat.bin", {6, 4, 2, 16}); + + const float sigma_cur = sigma_cur_ref.data()[0]; + const float sigma_next = sigma_next_ref.data()[0]; + + std::printf("\n=== Euler step parity (σ=%.4f → %.4f) ===\n", sigma_cur, sigma_next); + + // Run the DiT at σ_cur (pre-scaled by 1000 for AdaLN, via LTX2FlowDenoiser::sigma_to_t). + LTX2FlowDenoiser denoiser; + sd::Tensor t_in({1}); + t_in.data()[0] = denoiser.sigma_to_t(sigma_cur); + + sd::Tensor empty_mask; + auto v_cpp = runner.compute(/*n_threads=*/1, latent, t_in, context, empty_mask); + if (v_cpp.numel() != v_ref.numel()) { + std::fprintf(stderr, "fatal: velocity size mismatch\n"); + return false; + } + + // Compute x_next = latent + (σ_next - σ) * v (element-wise). 
+ sd::Tensor x_next_cpp(latent.shape()); + const float dt = sigma_next - sigma_cur; + for (int64_t i = 0; i < latent.numel(); ++i) { + x_next_cpp.data()[i] = latent.data()[i] + dt * v_cpp.data()[i]; + } + + auto sv = diff_fp32(v_cpp.data(), v_ref.data(), v_cpp.numel()); + auto sx = diff_fp32(x_next_cpp.data(), x_next_ref.data(), x_next_cpp.numel()); + + std::printf(" velocity@σ_cur: max_abs=%.2e mean_abs=%.2e max_rel=%.2e\n", + sv.max_abs, sv.mean_abs, sv.max_rel); + std::printf(" x_next: max_abs=%.2e mean_abs=%.2e max_rel=%.2e\n", + sx.max_abs, sx.mean_abs, sx.max_rel); + + // x_next is (latent + dt * v). dt is ~0.09, v drift ~1e-4 → x_next drift ~1e-5. Tolerances + // are roughly the same as the base DiT test since the Euler step doesn't amplify. + const float tol_abs = 1e-3f; + const float tol_rel = 5e-2f; + return sv.max_abs < tol_abs && sv.max_rel < tol_rel && + sx.max_abs < tol_abs && sx.max_rel < tol_rel; +} + +int main() { + const std::string ref_dir = "/tmp/ltx_ref"; + const std::string state_path = ref_dir + "/state_dict.safetensors"; + + // --- 1. Load the reference state dict. Weights are dumped with prefix "model.diffusion_model." + // which matches sd.cpp's default DiT location, so init_from_file with empty prefix passes names through. + ModelLoader loader; + if (!loader.init_from_file(state_path)) { + std::fprintf(stderr, "fatal: init_from_file failed for %s\n", state_path.c_str()); + return 1; + } + loader.convert_tensors_name(); // no-op for LTX-2 — names already match + const auto& tsm = loader.get_tensor_storage_map(); + std::printf("[state_dict] loaded %zu tensors from %s\n", tsm.size(), state_path.c_str()); + + // --- 2. Construct LTXRunner on CPU with explicit tiny-model params + // (the real LTX-2 hyperparams num_heads=32/head_dim=128 are auto-detected from weight shapes, + // but the tiny test uses num_heads=4/head_dim=32 which can't be inferred from q_norm alone). 
+ LTX::LTXParams tiny_params; + tiny_params.in_channels = 16; + tiny_params.out_channels = 16; + tiny_params.inner_dim = 128; + tiny_params.num_heads = 4; + tiny_params.head_dim = 32; + tiny_params.num_layers = 2; + tiny_params.cross_attention_dim = 128; + tiny_params.cross_attention_adaln = false; + tiny_params.apply_gated_attention = false; + + ggml_backend_t backend = ggml_backend_cpu_init(); + LTX::LTXRunner runner(backend, /*offload_params_to_cpu=*/false, tsm, + "model.diffusion_model", VERSION_LTX2, &tiny_params); + runner.set_fps(24.0f); + // Parity dump uses simplified (f, h, w) positions without VAE scale factors or + // causal_fix — mirror that here so positions match the Python reference. + runner.set_scale_factors(1, 1, 1); + runner.set_causal_fix(false); + + const auto& p = runner.ltx_params; + std::printf("[config] layers=%d inner=%ld heads=%d head_dim=%d " + "in=%ld out=%ld ca_dim=%ld\n", + p.num_layers, p.inner_dim, p.num_heads, p.head_dim, + p.in_channels, p.out_channels, p.cross_attention_dim); + + // --- 3. Allocate & load weights into the GGML graph. + runner.alloc_params_buffer(); + std::map param_tensors; + runner.get_param_tensors(param_tensors, "model.diffusion_model"); + std::printf("[load] loading %zu param tensors…\n", param_tensors.size()); + if (!loader.load_tensors(param_tensors)) { + std::fprintf(stderr, "fatal: load_tensors failed (some names unmatched?)\n"); + return 1; + } + + // --- 4. Load reference inputs. + // latent_unflat is dumped as [C=16, F=2, H=4, W=6] (C outermost, W innermost in memory). + // LTXRunner::build_graph expects ggml ne=[W, H, T=F, C], so sd::Tensor shape is {6, 4, 2, 16} + // (sd shape[0] = innermost dim). Raw memory layout is identical. 
+ auto latent = load_raw_bin(ref_dir + "/tensors/model__latent_unflat.bin", {6, 4, 2, 16}); + auto sigma_in = load_raw_bin(ref_dir + "/tensors/model__sigma.bin", {1}); + + // The C++ AdaLN now expects pre-scaled σ (see src/ltx.hpp:AdaLayerNormSingle docstring); + // the denoiser's sigma_to_t(σ)=σ*1000 will own this scaling in production. For the test + // we do it inline. + sd::Tensor timesteps({1}); + timesteps.data()[0] = sigma_in.data()[0] * 1000.0f; + + // context: Python shape [B=1, S=8, D=128] → ggml ne [128, 8, 1] → sd::Tensor shape {128, 8, 1}. + auto context = load_raw_bin(ref_dir + "/tensors/model__context_in.bin", {128, 8, 1}); + + sd::Tensor empty_mask; + + std::printf("[input] "); + print_shape("latent=", latent.shape()); + std::printf("[input] σ = %.6f → t = %.3f\n", sigma_in.data()[0], timesteps.data()[0]); + + // --- 5. Run forward. + std::printf("[compute] running single forward pass…\n"); + auto out = runner.compute(/*n_threads=*/1, latent, timesteps, context, empty_mask); + + print_shape("[output] out.shape = ", out.shape()); + + // Dump first & last few values to catch silent NaN / zeros before diffing. + std::printf("[output] first 8: "); + for (int i = 0; i < 8 && i < out.numel(); ++i) std::printf("%+.4f ", out.data()[i]); + std::printf("\n"); + std::printf("[output] last 8: "); + for (int64_t i = std::max(0, out.numel() - 8); i < out.numel(); ++i) std::printf("%+.4f ", out.data()[i]); + std::printf("\n"); + + // --- 6. Diff vs reference. 
+ auto ref = load_raw_bin(ref_dir + "/tensors/model__velocity_out_unflat.bin", {6, 4, 2, 16}); + if (out.numel() != ref.numel()) { + std::fprintf(stderr, "fatal: element count mismatch cpp=%ld ref=%ld\n", + out.numel(), ref.numel()); + return 1; + } + std::printf("[ref] first 8: "); + for (int i = 0; i < 8 && i < ref.numel(); ++i) std::printf("%+.4f ", ref.data()[i]); + std::printf("\n"); + + auto s = diff_fp32(out.data(), ref.data(), out.numel()); + std::printf("\n=== velocity_out parity ===\n"); + std::printf(" max_abs = %.3e (at index %ld: cpp=%.6f ref=%.6f)\n", + s.max_abs, s.max_abs_idx, + s.max_abs_idx >= 0 ? out.data()[s.max_abs_idx] : 0.f, + s.max_abs_idx >= 0 ? ref.data()[s.max_abs_idx] : 0.f); + std::printf(" mean_abs = %.3e\n", s.mean_abs); + std::printf(" max_rel = %.3e\n", s.max_rel); + std::printf(" n = %ld\n\n", out.numel()); + + // FP32 tolerances realistic for multi-layer DiT: accumulation order (ggml's mat-mul vs + // torch.matmul), softmax + rope + rms_norm order-of-ops, and bf16 casts in flash-attn paths + // all add ~1e-4 abs / ~1e-2 rel drift per block. Mean_abs is the more stable indicator. + // + // max_rel is only meaningful when every |ref[i]| is comfortably above the expected noise + // floor. The V1 reference happens to contain a single element with |ref| ≈ 4e-5, so a + // 1e-5 abs drift (far below our max_abs tolerance) alone pushes max_rel to ~0.3. Skip + // the max_rel check here for the same reason V2-deep does — abs/mean catch real drift + // and the near-zero rel spike is noise. + const float tol_max_abs = 1e-3f; + const float tol_mean_abs = 2e-4f; + bool pass_dit = s.max_abs < tol_max_abs && s.mean_abs < tol_mean_abs; + std::printf("%s (tol: max_abs<%.1e mean_abs<%.1e; max_rel ignored due to near-zero divisors)\n", + pass_dit ? "DiT parity: PASS" : "DiT parity: FAIL", + tol_max_abs, tol_mean_abs); + + bool pass_sched = check_schedule(ref_dir); + std::printf("%s\n", pass_sched ? 
"Scheduler parity: PASS" : "Scheduler parity: FAIL"); + + bool pass_euler = check_euler_step(ref_dir, runner, latent, context); + std::printf("%s\n", pass_euler ? "Euler step parity: PASS" : "Euler step parity: FAIL"); + + // --- V2 parity (cross_attention_adaln=true + apply_gated_attention=true) ----------------- + // The V1 check above validates the base path with both V2 features disabled. The production + // 22B checkpoint uses both. This block reloads the same state_dict with a V2-flagged runner + // and compares against Python's `v2model/velocity_out_unflat` dump. + std::printf("\n=== V2 parity (cross_attention_adaln + apply_gated_attention) ===\n"); + LTX::LTXParams v2_params; + v2_params.in_channels = 16; + v2_params.out_channels = 16; + v2_params.inner_dim = 128; + v2_params.num_heads = 4; + v2_params.head_dim = 32; + v2_params.num_layers = 2; + v2_params.cross_attention_dim = 128; + v2_params.cross_attention_adaln = true; + v2_params.apply_gated_attention = true; + + LTX::LTXRunner v2_runner(backend, /*offload_params_to_cpu=*/false, tsm, + "model.diffusion_model_v2", VERSION_LTX2, &v2_params); + v2_runner.set_fps(24.0f); + v2_runner.set_scale_factors(1, 1, 1); + v2_runner.set_causal_fix(false); + v2_runner.alloc_params_buffer(); + + std::map v2_param_tensors; + v2_runner.get_param_tensors(v2_param_tensors, "model.diffusion_model_v2"); + std::printf("[v2] loading %zu param tensors under model.diffusion_model_v2\n", v2_param_tensors.size()); + if (!loader.load_tensors(v2_param_tensors)) { + std::fprintf(stderr, "fatal: V2 load_tensors failed\n"); + return 1; + } + + auto v2_latent = load_raw_bin(ref_dir + "/tensors/v2model__latent_unflat.bin", {6, 4, 2, 16}); + auto v2_sigma = load_raw_bin(ref_dir + "/tensors/v2model__sigma.bin", {1}); + sd::Tensor v2_timesteps({1}); + v2_timesteps.data()[0] = v2_sigma.data()[0] * 1000.0f; + auto v2_context = load_raw_bin(ref_dir + "/tensors/v2model__context_in.bin", {128, 8, 1}); + sd::Tensor v2_empty_mask; + + auto v2_out 
= v2_runner.compute(/*n_threads=*/1, v2_latent, v2_timesteps, v2_context, v2_empty_mask); + auto v2_ref = load_raw_bin(ref_dir + "/tensors/v2model__velocity_out_unflat.bin", {6, 4, 2, 16}); + + std::printf("[v2 output] first 8: "); + for (int i = 0; i < 8 && i < v2_out.numel(); ++i) std::printf("%+.4f ", v2_out.data()[i]); + std::printf("\n[v2 ref] first 8: "); + for (int i = 0; i < 8 && i < v2_ref.numel(); ++i) std::printf("%+.4f ", v2_ref.data()[i]); + std::printf("\n"); + + auto sv2 = diff_fp32(v2_out.data(), v2_ref.data(), v2_out.numel()); + std::printf(" max_abs = %.3e (at index %ld: cpp=%.6f ref=%.6f)\n", + sv2.max_abs, sv2.max_abs_idx, + sv2.max_abs_idx >= 0 ? v2_out.data()[sv2.max_abs_idx] : 0.f, + sv2.max_abs_idx >= 0 ? v2_ref.data()[sv2.max_abs_idx] : 0.f); + std::printf(" mean_abs = %.3e\n", sv2.mean_abs); + std::printf(" max_rel = %.3e\n", sv2.max_rel); + + // Same max_rel skip as the V1 block above: the reference can contain a handful of + // near-zero elements whose tiny abs drift blows the relative error up without being + // a real parity regression. abs/mean catch actual drift. + bool pass_v2 = sv2.max_abs < tol_max_abs && sv2.mean_abs < tol_mean_abs; + std::printf("%s (tol: max_abs<%.1e mean_abs<%.1e; max_rel ignored due to near-zero divisors)\n", + pass_v2 ? "V2 DiT parity: PASS" : "V2 DiT parity: FAIL", + tol_max_abs, tol_mean_abs); + + // --- V2-deep parity: 8 layers + non-zero scale_shift_table ------------------------------- + // The V2 check above uses 2 layers with zeroed sst, so modulation is effectively identity + // and can hide sign/broadcast bugs in the (1+scale) and shift_kv/scale_kv branches. This + // block loads an 8-layer variant with randomised sst weights so any cross-layer drift in + // the V2 path surfaces. 
+ std::printf("\n=== V2-deep parity (8 layers + non-zero scale_shift_table) ===\n"); + LTX::LTXParams v2_deep_params = v2_params; + v2_deep_params.num_layers = 8; + + LTX::LTXRunner v2_deep_runner(backend, /*offload_params_to_cpu=*/false, tsm, + "model.diffusion_model_v2_deep", VERSION_LTX2, &v2_deep_params); + v2_deep_runner.set_fps(24.0f); + v2_deep_runner.set_scale_factors(1, 1, 1); + v2_deep_runner.set_causal_fix(false); + v2_deep_runner.alloc_params_buffer(); + + std::map v2_deep_param_tensors; + v2_deep_runner.get_param_tensors(v2_deep_param_tensors, "model.diffusion_model_v2_deep"); + std::printf("[v2-deep] loading %zu param tensors\n", v2_deep_param_tensors.size()); + if (!loader.load_tensors(v2_deep_param_tensors)) { + std::fprintf(stderr, "fatal: V2-deep load_tensors failed\n"); + return 1; + } + + auto v2d_latent = load_raw_bin(ref_dir + "/tensors/v2deep__latent_unflat.bin", {6, 4, 2, 16}); + auto v2d_sigma = load_raw_bin(ref_dir + "/tensors/v2deep__sigma.bin", {1}); + sd::Tensor v2d_timesteps({1}); + v2d_timesteps.data()[0] = v2d_sigma.data()[0] * 1000.0f; + auto v2d_context = load_raw_bin(ref_dir + "/tensors/v2deep__context_in.bin", {128, 8, 1}); + sd::Tensor v2d_empty_mask; + + auto v2d_out = v2_deep_runner.compute(/*n_threads=*/1, v2d_latent, v2d_timesteps, v2d_context, v2d_empty_mask); + auto v2d_ref = load_raw_bin(ref_dir + "/tensors/v2deep__velocity_out_unflat.bin", {6, 4, 2, 16}); + + std::printf("[v2-deep output] first 8: "); + for (int i = 0; i < 8 && i < v2d_out.numel(); ++i) std::printf("%+.4f ", v2d_out.data()[i]); + std::printf("\n[v2-deep ref] first 8: "); + for (int i = 0; i < 8 && i < v2d_ref.numel(); ++i) std::printf("%+.4f ", v2d_ref.data()[i]); + std::printf("\n"); + + auto sv2d = diff_fp32(v2d_out.data(), v2d_ref.data(), v2d_out.numel()); + std::printf(" max_abs = %.3e (at index %ld: cpp=%.6f ref=%.6f)\n", + sv2d.max_abs, sv2d.max_abs_idx, + sv2d.max_abs_idx >= 0 ? v2d_out.data()[sv2d.max_abs_idx] : 0.f, + sv2d.max_abs_idx >= 0 ? 
v2d_ref.data()[sv2d.max_abs_idx] : 0.f); + std::printf(" mean_abs = %.3e\n", sv2d.mean_abs); + std::printf(" max_rel = %.3e\n", sv2d.max_rel); + + // Tolerance: max_rel is dropped here because per-element rel_err with b_i in the 1e-4 + // range produces meaningless blow-ups (100% rel for 1e-4 abs). max_abs and mean_abs are + // the reliable signals — both on the order of the 2-layer V2 test confirms no accumulated + // drift across 8 layers × non-zero sst modulation. + const float tol_max_abs_deep = 5e-3f; + const float tol_mean_abs_deep = 1e-3f; + bool pass_v2_deep = sv2d.max_abs < tol_max_abs_deep && sv2d.mean_abs < tol_mean_abs_deep; + std::printf("%s (tol: max_abs<%.1e mean_abs<%.1e; max_rel ignored due to near-zero divisors)\n", + pass_v2_deep ? "V2-deep DiT parity: PASS" : "V2-deep DiT parity: FAIL", + tol_max_abs_deep, tol_mean_abs_deep); + + bool pass = pass_dit && pass_sched && pass_euler && pass_v2 && pass_v2_deep; + std::printf("\n%s\n", pass ? "ALL PARITY: PASS" : "ALL PARITY: FAIL"); + return pass ? 0 : 3; +} diff --git a/tests/ltx_parity/test_mm_f32_parity.cpp b/tests/ltx_parity/test_mm_f32_parity.cpp new file mode 100644 index 000000000..bbe23668f --- /dev/null +++ b/tests/ltx_parity/test_mm_f32_parity.cpp @@ -0,0 +1,240 @@ +// Minimal F32 mul_mat parity test: builds a tiny graph with one mul_mat +// node, runs it on the CPU and CUDA backends with deterministic synthetic +// inputs, and diffs the outputs. Used to isolate where Gemma's attention +// kqv = mul_mat(v, softmax(kq)) drift comes from. +// +// Usage: +// sd-mm-f32-parity [shape: K M N B] # default 128 256 128 16 +// +// Where the matmul computed is: +// dst[i, j, b] = sum_k src0[k, i, b/r2] * src1[k, j, b] +// with src0=[K,M,B/r2], src1=[K,N,B] in ggml convention. +// +// The default shape matches Gemma 3 12B's V·softmax matmul at batch=128. 
+ +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#ifdef SD_USE_CUDA +#include "ggml-cuda.h" +#endif +#include "ggml.h" + +struct Shape { + int K; + int M; + int N; + int B; // ne[2] of src1 + int r2; // src1->ne[2] / src0->ne[2] +}; + +static std::vector run(ggml_backend_t backend, + const std::vector& src0_data, + const std::vector& src1_data, + const Shape& s, + bool prec_f32) { + // Allocate ggml context + struct ggml_init_params params = {}; + params.mem_size = 16 * 1024 * 1024; + params.mem_buffer = nullptr; + params.no_alloc = true; + ggml_context* ctx = ggml_init(params); + + const int B0 = s.B / s.r2; + auto src0 = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, s.K, s.M, B0); + auto src1 = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, s.K, s.N, s.B); + ggml_set_name(src0, "src0"); + ggml_set_name(src1, "src1"); + + auto dst = ggml_mul_mat(ctx, src0, src1); + if (prec_f32) ggml_mul_mat_set_prec(dst, GGML_PREC_F32); + ggml_set_name(dst, "dst"); + ggml_set_output(dst); + + // Build graph + ggml_cgraph* gf = ggml_new_graph(ctx); + ggml_build_forward_expand(gf, dst); + + // Allocate backend buffer + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (buf == nullptr) { + std::fprintf(stderr, "fatal: ggml_backend_alloc_ctx_tensors failed\n"); + std::exit(1); + } + + // Upload inputs + ggml_backend_tensor_set(src0, src0_data.data(), 0, ggml_nbytes(src0)); + ggml_backend_tensor_set(src1, src1_data.data(), 0, ggml_nbytes(src1)); + + // Allocate compute buffer + ggml_gallocr_t gallocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(gallocr, gf)) { + std::fprintf(stderr, "fatal: ggml_gallocr_alloc_graph failed\n"); + std::exit(1); + } + + // Compute + if (ggml_backend_graph_compute(backend, gf) != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "fatal: graph compute failed\n"); + std::exit(1); + } + + // Read output + std::vector 
out(ggml_nelements(dst)); + ggml_backend_tensor_get(dst, out.data(), 0, ggml_nbytes(dst)); + + ggml_gallocr_free(gallocr); + ggml_backend_buffer_free(buf); + ggml_free(ctx); + return out; +} + +int main(int argc, char** argv) { + Shape s{128, 256, 128, 16, 2}; + if (argc >= 5) { + s.K = std::atoi(argv[1]); + s.M = std::atoi(argv[2]); + s.N = std::atoi(argv[3]); + s.B = std::atoi(argv[4]); + } + if (argc >= 6) s.r2 = std::atoi(argv[5]); + + std::printf("Shape: K=%d M=%d N=%d B=%d r2=%d (src0=[%d,%d,%d], src1=[%d,%d,%d])\n", + s.K, s.M, s.N, s.B, s.r2, + s.K, s.M, s.B / s.r2, s.K, s.N, s.B); + + // Deterministic inputs. Default = small magnitudes; SD_MM_DIST=gemma uses + // wider v values and proper softmax (heavy-tailed) distribution to mirror + // Gemma 3 12B attention's V·softmax matmul. + std::mt19937 rng(42); + std::normal_distribution dist0(0.0f, 1.0f); + std::normal_distribution dist1(0.0f, 0.05f); + + const int n0 = s.K * s.M * (s.B / s.r2); + const int n1 = s.K * s.N * s.B; + std::vector src0(n0), src1(n1); + + const char* dist_mode = std::getenv("SD_MM_DIST"); + const char* load_dir = std::getenv("SD_MM_LOAD"); + if (load_dir != nullptr) { + // Load src0 = "_attn_v.bin" and src1 = "_attn_softmax.bin" from disk. + // Override shape from the .shape file alongside. 
+ auto load_bin = [](const std::string& path, std::vector& out) { + FILE* fp = std::fopen(path.c_str(), "rb"); + if (!fp) { std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str()); std::exit(1); } + std::fseek(fp, 0, SEEK_END); + long sz = std::ftell(fp); + std::fseek(fp, 0, SEEK_SET); + out.resize(sz / sizeof(float)); + std::fread(out.data(), 1, sz, fp); + std::fclose(fp); + }; + auto load_shape = [](const std::string& path) { + FILE* fp = std::fopen(path.c_str(), "r"); + if (!fp) { std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str()); std::exit(1); } + long ne[4]; char tname[32]; + std::fscanf(fp, "%ld %ld %ld %ld %31s", &ne[0], &ne[1], &ne[2], &ne[3], tname); + std::fclose(fp); + return std::vector{ne[0], ne[1], ne[2], ne[3]}; + }; + std::string base = load_dir; + load_bin(base + "/_attn_v.bin", src0); + load_bin(base + "/_attn_softmax.bin", src1); + auto sh0 = load_shape(base + "/_attn_v.shape"); + auto sh1 = load_shape(base + "/_attn_softmax.shape"); + // src0 shape is [K, M, B/r2], src1 shape is [K, N, B]. From dumps: + // _attn_v shape [128, 256, 8, 1] → K=128, M=256, B/r2=8 + // _attn_softmax shape [128, 128, 16, 1] → K=128, N=128, B=16 + s.K = (int) sh0[0]; + s.M = (int) sh0[1]; + s.N = (int) sh1[1]; + s.B = (int) sh1[2]; + s.r2 = s.B / (int) sh0[2]; + std::printf("[load] src0 shape=[%ld,%ld,%ld] from %s/_attn_v.bin\n", sh0[0], sh0[1], sh0[2], load_dir); + std::printf("[load] src1 shape=[%ld,%ld,%ld] from %s/_attn_softmax.bin\n", sh1[0], sh1[1], sh1[2], load_dir); + } else if (dist_mode && std::strcmp(dist_mode, "gemma") == 0) { + // v values: ~N(0, 50) like Gemma's v_proj output (mean_mag ~ 54) + std::normal_distribution dv(0.0f, 50.0f); + for (int i = 0; i < n0; ++i) src0[i] = dv(rng); + // softmax weights: per row of K, exp normalised. Sums to 1 per row. 
+ for (int b = 0; b < s.B; ++b) { + for (int n = 0; n < s.N; ++n) { + float* row = &src1[b * s.N * s.K + n * s.K]; + float max = -1e30f; + for (int k = 0; k < s.K; ++k) { + row[k] = std::normal_distribution(0.0f, 5.0f)(rng); + if (row[k] > max) max = row[k]; + } + float sum = 0.0f; + for (int k = 0; k < s.K; ++k) { + row[k] = std::exp(row[k] - max); + sum += row[k]; + } + for (int k = 0; k < s.K; ++k) row[k] /= sum; + } + } + } else { + for (int i = 0; i < n0; ++i) src0[i] = dist0(rng); + for (int i = 0; i < n1; ++i) src1[i] = dist1(rng); + } + + // CPU + std::printf("Running CPU forward...\n"); + ggml_backend_t cpu_backend = ggml_backend_cpu_init(); + auto cpu_out = run(cpu_backend, src0, src1, s, /*prec_f32=*/true); + ggml_backend_free(cpu_backend); + + // CUDA +#ifdef SD_USE_CUDA + int cuda_device = 1; + if (const char* d = std::getenv("SD_CUDA_DEVICE")) cuda_device = std::atoi(d); + std::printf("Running CUDA forward (device %d)...\n", cuda_device); + ggml_backend_t cuda_backend = ggml_backend_cuda_init(cuda_device); + if (!cuda_backend) { + std::fprintf(stderr, "fatal: CUDA backend init failed for device %d\n", cuda_device); + return 1; + } + auto cuda_out = run(cuda_backend, src0, src1, s, /*prec_f32=*/true); + ggml_backend_free(cuda_backend); +#else + std::fprintf(stderr, "fatal: built without SD_USE_CUDA\n"); + return 1; +#endif + + // Diff + if (cpu_out.size() != cuda_out.size()) { + std::fprintf(stderr, "fatal: output size mismatch\n"); + return 1; + } + double max_abs = 0.0, sum_abs = 0.0, sum_cpu_abs = 0.0; + int argmax = 0; + for (size_t i = 0; i < cpu_out.size(); ++i) { + double diff = std::fabs((double) cpu_out[i] - (double) cuda_out[i]); + if (diff > max_abs) { + max_abs = diff; + argmax = (int) i; + } + sum_abs += diff; + sum_cpu_abs += std::fabs(cpu_out[i]); + } + const double mean_abs = sum_abs / cpu_out.size(); + const double cpu_mag = sum_cpu_abs / cpu_out.size(); + std::printf("Diff: max_abs=%.6e mean_abs=%.6e cpu_mean_mag=%.6e (rel_max=%.3e 
rel_mean=%.3e)\n", + max_abs, mean_abs, cpu_mag, + cpu_mag > 0 ? max_abs / cpu_mag : 0.0, + cpu_mag > 0 ? mean_abs / cpu_mag : 0.0); + std::printf("Argmax element [idx %d]: CPU=%+.9e CUDA=%+.9e\n", argmax, cpu_out[argmax], cuda_out[argmax]); + std::printf("First 6 elements:\n"); + for (int i = 0; i < 6 && i < (int) cpu_out.size(); ++i) { + std::printf(" [%2d] CPU=%+.9e CUDA=%+.9e diff=%+.3e\n", + i, cpu_out[i], cuda_out[i], cuda_out[i] - cpu_out[i]); + } + + return 0; +} diff --git a/tests/ltx_parity/test_s2d_primitives.cpp b/tests/ltx_parity/test_s2d_primitives.cpp new file mode 100644 index 000000000..a2758ffd1 --- /dev/null +++ b/tests/ltx_parity/test_s2d_primitives.cpp @@ -0,0 +1,185 @@ +// Standalone test: verify our axis-W / axis-H / axis-T SpaceToDepth and +// DepthToSpace ggml recipes against Python einops `rearrange(...)` outputs +// dumped by dump_s2d.py. Composition tests cover the stride patterns used +// by the VAE: (2,2,2), (1,2,2), (2,1,1). + +#include +#include +#include +#include +#include +#include + +#include "ggml-cpu.h" +#include "ggml.h" +#include "ltxvae_primitives.hpp" + +namespace { + +constexpr int B = 1; +constexpr int C = 3; +constexpr int T = 4; +constexpr int H = 6; +constexpr int W = 8; +constexpr int FACTOR = 2; + +std::vector load_bin(const std::string& path, size_t expected_numel) { + std::ifstream f(path, std::ios::binary); + if (!f.is_open()) { std::fprintf(stderr, "cannot open %s\n", path.c_str()); std::exit(2); } + std::vector buf(expected_numel); + f.read(reinterpret_cast(buf.data()), expected_numel * sizeof(float)); + if (!f.good()) { std::fprintf(stderr, "short read on %s\n", path.c_str()); std::exit(2); } + return buf; +} + +enum Kind { S2D_W, S2D_H, S2D_T, S2D_222, S2D_122, S2D_211, + D2S_W, D2S_H, D2S_T, D2S_222, D2S_122, D2S_211, + PIXEL_NORM, PCS_NORMALIZE, PCS_UNNORMALIZE }; + +struct CaseSpec { + const char* name; + std::vector in_shape_ne; + std::vector expected_shape_ne; + Kind kind; +}; + +bool run_case(const 
CaseSpec& cs, const std::string& ref_dir) { + size_t in_numel = 1, out_numel = 1; + for (auto d : cs.in_shape_ne) in_numel *= d; + for (auto d : cs.expected_shape_ne) out_numel *= d; + + const bool is_d2s = (cs.kind >= D2S_W && cs.kind <= D2S_211); + const bool is_pn = (cs.kind == PIXEL_NORM); + const bool is_pcs = (cs.kind == PCS_NORMALIZE || cs.kind == PCS_UNNORMALIZE); + std::string in_file, exp_file; + if (is_pn) { + in_file = ref_dir + "/tensors/pn_input.bin"; + exp_file = ref_dir + "/tensors/pn_expected.bin"; + } else if (is_pcs) { + in_file = ref_dir + "/tensors/pcs_input.bin"; + exp_file = ref_dir + "/tensors/" + + std::string(cs.kind == PCS_NORMALIZE ? "pcs_normalize_expected.bin" + : "pcs_unnormalize_expected.bin"); + } else { + in_file = ref_dir + "/tensors/" + (is_d2s ? "dinput_" : "input_") + cs.name + ".bin"; + exp_file = ref_dir + "/tensors/" + (is_d2s ? "dexpected_" : "expected_") + cs.name + ".bin"; + } + auto in_data = load_bin(in_file, in_numel); + auto expected = load_bin(exp_file, out_numel); + + size_t mem_size = 128 * 1024 * 1024; + std::vector mem_buf(mem_size); + ggml_init_params params{mem_size, mem_buf.data(), false}; + ggml_context* ctx = ggml_init(params); + + ggml_tensor* x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, + cs.in_shape_ne[0], cs.in_shape_ne[1], + cs.in_shape_ne[2], cs.in_shape_ne[3]); + std::memcpy(x->data, in_data.data(), in_numel * sizeof(float)); + + ggml_tensor* y = nullptr; + ggml_tensor* mu_t = nullptr; + ggml_tensor* sigma_t = nullptr; + std::vector mu_data, sigma_data; + switch (cs.kind) { + case S2D_W: y = LTXVAE::space_to_depth_axisW(ctx, x, FACTOR); break; + case S2D_H: y = LTXVAE::space_to_depth_axisH(ctx, x, FACTOR); break; + case S2D_T: y = LTXVAE::space_to_depth_axisT(ctx, x, FACTOR); break; + case S2D_222: y = LTXVAE::space_to_depth(ctx, x, FACTOR, FACTOR, FACTOR); break; + case S2D_122: y = LTXVAE::space_to_depth(ctx, x, 1, FACTOR, FACTOR); break; + case S2D_211: y = LTXVAE::space_to_depth(ctx, x, FACTOR, 1, 
1); break; + case D2S_W: y = LTXVAE::depth_to_space_axisW(ctx, x, FACTOR); break; + case D2S_H: y = LTXVAE::depth_to_space_axisH(ctx, x, FACTOR); break; + case D2S_T: y = LTXVAE::depth_to_space_axisT(ctx, x, FACTOR); break; + case D2S_222: y = LTXVAE::depth_to_space(ctx, x, FACTOR, FACTOR, FACTOR); break; + case D2S_122: y = LTXVAE::depth_to_space(ctx, x, 1, FACTOR, FACTOR); break; + case D2S_211: y = LTXVAE::depth_to_space(ctx, x, FACTOR, 1, 1); break; + case PIXEL_NORM: y = LTXVAE::pixel_norm(ctx, x, 1e-8f); break; + case PCS_NORMALIZE: + case PCS_UNNORMALIZE: { + int64_t C = cs.in_shape_ne[3]; + mu_data = load_bin(ref_dir + "/tensors/pcs_mu.bin", (size_t)C); + sigma_data = load_bin(ref_dir + "/tensors/pcs_sigma.bin", (size_t)C); + mu_t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, C); + sigma_t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, C); + std::memcpy(mu_t->data, mu_data.data(), C * sizeof(float)); + std::memcpy(sigma_t->data, sigma_data.data(), C * sizeof(float)); + y = (cs.kind == PCS_NORMALIZE) + ? 
LTXVAE::pcs_normalize(ctx, x, mu_t, sigma_t) + : LTXVAE::pcs_unnormalize(ctx, x, mu_t, sigma_t); + } break; + } + + ggml_cgraph* gf = ggml_new_graph(ctx); + ggml_build_forward_expand(gf, y); + ggml_graph_compute_with_ctx(ctx, gf, 1); + + bool shape_ok = true; + for (int i = 0; i < 4; i++) if (y->ne[i] != cs.expected_shape_ne[i]) { shape_ok = false; break; } + if (!shape_ok) { + std::printf(" %-18s SHAPE_FAIL got=[%lld,%lld,%lld,%lld] exp=[%lld,%lld,%lld,%lld]\n", + cs.name, + (long long)y->ne[0], (long long)y->ne[1], (long long)y->ne[2], (long long)y->ne[3], + (long long)cs.expected_shape_ne[0], (long long)cs.expected_shape_ne[1], + (long long)cs.expected_shape_ne[2], (long long)cs.expected_shape_ne[3]); + ggml_free(ctx); + return false; + } + + const float* got = (const float*)y->data; + float max_abs = 0.f; + int64_t first_diff = -1; + for (size_t i = 0; i < out_numel; i++) { + float d = std::abs(got[i] - expected[i]); + if (d > max_abs) { max_abs = d; if (first_diff < 0) first_diff = (int64_t)i; } + } + // PixelNorm / PCS involve f32 divides & rms; relax tolerance slightly. + float tol = (cs.kind >= PIXEL_NORM) ? 5e-6f : 1e-6f; + bool pass = max_abs < tol; + std::printf(" %-18s %s max_abs=%.3e", cs.name, pass ? 
"PASS" : "FAIL", max_abs); + if (!pass && first_diff >= 0) { + std::printf(" first_diff_idx=%lld got=%.6f exp=%.6f", + (long long)first_diff, got[first_diff], expected[first_diff]); + } + std::printf("\n"); + + ggml_free(ctx); + return pass; +} + +} // namespace + +int main() { + const std::string ref_dir = "/tmp/s2d_ref"; + + std::vector cases = { + // SpaceToDepth + {"axisW", {W*FACTOR, H, T, C}, {W, H, T, C*FACTOR}, S2D_W}, + {"axisH", {W, H*FACTOR, T, C}, {W, H, T, C*FACTOR}, S2D_H}, + {"axisT", {W, H, T*FACTOR, C}, {W, H, T, C*FACTOR}, S2D_T}, + {"full222", {W*FACTOR, H*FACTOR, T*FACTOR, C}, {W, H, T, C*8}, S2D_222}, + {"full122", {W*FACTOR, H*FACTOR, T, C}, {W, H, T, C*4}, S2D_122}, + {"full211", {W, H, T*FACTOR, C}, {W, H, T, C*2}, S2D_211}, + // DepthToSpace (input has extra channels) + {"axisW", {W, H, T, C*FACTOR}, {W*FACTOR, H, T, C}, D2S_W}, + {"axisH", {W, H, T, C*FACTOR}, {W, H*FACTOR, T, C}, D2S_H}, + {"axisT", {W, H, T, C*FACTOR}, {W, H, T*FACTOR, C}, D2S_T}, + {"full222", {W, H, T, C*8}, {W*FACTOR, H*FACTOR, T*FACTOR, C}, D2S_222}, + {"full122", {W, H, T, C*4}, {W*FACTOR, H*FACTOR, T, C}, D2S_122}, + {"full211", {W, H, T, C*2}, {W, H, T*FACTOR, C}, D2S_211}, + // PixelNorm (dim=channel) and PerChannelStatistics + {"pn", {W, H, T, 5}, {W, H, T, 5}, PIXEL_NORM}, + {"pcs_norm", {W, H, T, 6}, {W, H, T, 6}, PCS_NORMALIZE}, + {"pcs_unnorm", {W, H, T, 6}, {W, H, T, 6}, PCS_UNNORMALIZE}, + }; + + std::printf("SpaceToDepth primitive parity:\n"); + int pass = 0; + for (size_t i = 0; i < cases.size(); i++) { + if (i == 6) std::printf("\nDepthToSpace primitive parity:\n"); + if (i == 12) std::printf("\nNorm primitives parity:\n"); + if (run_case(cases[i], ref_dir)) pass++; + } + std::printf("\n%d / %zu cases passed.\n", pass, cases.size()); + return (pass == (int)cases.size()) ? 
0 : 3; +} diff --git a/tests/ltx_parity/test_softmax_parity.cpp b/tests/ltx_parity/test_softmax_parity.cpp new file mode 100644 index 000000000..89d894bc3 --- /dev/null +++ b/tests/ltx_parity/test_softmax_parity.cpp @@ -0,0 +1,120 @@ +// Standalone softmax parity test on the same inputs the Gemma attention +// path computes for kq just before softmax. Loads _attn_kq_masked.bin +// from a CPU dump (or generates synthetic), runs ggml_soft_max on CPU and +// CUDA backends, diffs. + +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#ifdef SD_USE_CUDA +#include "ggml-cuda.h" +#endif +#include "ggml.h" + +static std::vector run(ggml_backend_t backend, + const std::vector& src_data, + int K, int N, int B) { + struct ggml_init_params params = {}; + params.mem_size = 16 * 1024 * 1024; + params.mem_buffer = nullptr; + params.no_alloc = true; + ggml_context* ctx = ggml_init(params); + + auto src = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K, N, B); + ggml_set_name(src, "src"); + + auto dst = ggml_soft_max(ctx, src); + ggml_set_name(dst, "dst"); + ggml_set_output(dst); + + ggml_cgraph* gf = ggml_new_graph(ctx); + ggml_build_forward_expand(gf, dst); + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (buf == nullptr) { std::fprintf(stderr, "fatal: alloc failed\n"); std::exit(1); } + ggml_backend_tensor_set(src, src_data.data(), 0, ggml_nbytes(src)); + + ggml_gallocr_t gallocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!ggml_gallocr_alloc_graph(gallocr, gf)) { std::fprintf(stderr, "fatal: alloc graph\n"); std::exit(1); } + if (ggml_backend_graph_compute(backend, gf) != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, "fatal: compute failed\n"); std::exit(1); + } + std::vector out(ggml_nelements(dst)); + ggml_backend_tensor_get(dst, out.data(), 0, ggml_nbytes(dst)); + + ggml_gallocr_free(gallocr); + ggml_backend_buffer_free(buf); + ggml_free(ctx); + return out; +} 
+ +int main(int argc, char** argv) { + int K = 128, N = 128, B = 16; + std::vector src; + + const char* load_path = std::getenv("SD_SOFTMAX_LOAD"); + if (load_path) { + FILE* fp = std::fopen(load_path, "rb"); + if (!fp) { std::fprintf(stderr, "fatal: cannot open %s\n", load_path); return 1; } + std::fseek(fp, 0, SEEK_END); + long sz = std::ftell(fp); + std::fseek(fp, 0, SEEK_SET); + src.resize(sz / sizeof(float)); + std::fread(src.data(), 1, sz, fp); + std::fclose(fp); + std::printf("[load] read %ld bytes from %s\n", sz, load_path); + // Need to also know shape. Hardcode kq shape for now. + K = 128; N = 128; B = 16; + if ((int)(K * N * B) != (int) src.size()) { + std::fprintf(stderr, "size mismatch: K*N*B=%d but file has %zu floats\n", + K * N * B, src.size()); + return 1; + } + } else { + // Synthetic kq-like values: roughly [-15, 15] like scaled attention logits + std::mt19937 rng(42); + std::normal_distribution dist(0.0f, 5.0f); + src.resize(K * N * B); + for (auto& x : src) x = dist(rng); + } + + std::printf("Shape: K=%d N=%d B=%d\n", K, N, B); + + std::printf("Running CPU forward...\n"); + ggml_backend_t cpu_backend = ggml_backend_cpu_init(); + auto cpu_out = run(cpu_backend, src, K, N, B); + ggml_backend_free(cpu_backend); + +#ifdef SD_USE_CUDA + int cuda_device = 1; + if (const char* d = std::getenv("SD_CUDA_DEVICE")) cuda_device = std::atoi(d); + std::printf("Running CUDA forward (device %d)...\n", cuda_device); + ggml_backend_t cuda_backend = ggml_backend_cuda_init(cuda_device); + if (!cuda_backend) { std::fprintf(stderr, "fatal: CUDA init failed\n"); return 1; } + auto cuda_out = run(cuda_backend, src, K, N, B); + ggml_backend_free(cuda_backend); +#endif + + if (cpu_out.size() != cuda_out.size()) { std::fprintf(stderr, "size mismatch\n"); return 1; } + + double max_abs = 0.0, sum_abs = 0.0, sum_cpu_abs = 0.0; + int argmax = 0, n_diff = 0; + for (size_t i = 0; i < cpu_out.size(); ++i) { + double diff = std::fabs((double) cpu_out[i] - (double) cuda_out[i]); 
+ if (diff > 0) n_diff++; + if (diff > max_abs) { max_abs = diff; argmax = (int) i; } + sum_abs += diff; + sum_cpu_abs += std::fabs(cpu_out[i]); + } + std::printf("Diff: max=%.6e mean=%.6e cpu_mean_mag=%.6e n_diff=%d/%zu\n", + max_abs, sum_abs / cpu_out.size(), sum_cpu_abs / cpu_out.size(), + n_diff, cpu_out.size()); + std::printf("Argmax [idx %d]: CPU=%+.9e CUDA=%+.9e\n", argmax, cpu_out[argmax], cuda_out[argmax]); + return 0; +} diff --git a/tests/ltx_parity/test_vae_parity.cpp b/tests/ltx_parity/test_vae_parity.cpp new file mode 100644 index 000000000..f4005683d --- /dev/null +++ b/tests/ltx_parity/test_vae_parity.cpp @@ -0,0 +1,378 @@ +// LTX-2 VAE C++ parity test. +// +// Loads /tmp/vae_ref/state_dict.safetensors (from dump_vae.py) plus the per-stage +// reference trace tensors, runs our C++ VideoEncoder + VideoDecoder, and diffs +// each stage against the Python reference. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "ggml-alloc.h" +#include "ltxvae.hpp" + +// Standalone GGMLRunner that wraps a single LTXVAE::TimestepEmbedder block so we can +// isolate the sinusoidal + 2-linear path from the full VAE pipeline. 
+struct TERunner : public GGMLRunner { + LTXVAE::TimestepEmbedder te; + + TERunner(ggml_backend_t backend, bool offload, const String2TensorStorage& tsm, + const std::string& prefix, int embedding_dim) + : GGMLRunner(backend, offload), te(embedding_dim) { + te.init(params_ctx, tsm, prefix); + } + std::string get_desc() override { return "ltx2_vae_te_probe"; } + void get_param_tensors(std::map& tensors, const std::string& prefix) { + te.get_param_tensors(tensors, prefix); + } + sd::Tensor compute(int n_threads, const sd::Tensor& timestep) { + auto get_g = [&]() -> ggml_cgraph* { + ggml_cgraph* gf = ggml_new_graph(compute_ctx); + ggml_tensor* t = make_input(timestep); + auto runner_ctx = get_context(); + auto out = te.forward(&runner_ctx, t); + ggml_build_forward_expand(gf, out); + return gf; + }; + return take_or_empty(GGMLRunner::compute(get_g, n_threads, true)); + } +}; + +// Runs JUST the ada-values reshape+slice on a pre-computed time_embed. Returns one of +// the 4 slices (chosen by `which` in 0..3 → shift1, scale1, shift2, scale2). This +// isolates the PyTorch `timestep.reshape(B, 4, -1, 1, 1, 1)` → unbind(dim=1) path +// in pure GGML ops to verify memory-order correctness. 
+struct ShiftProbeRunner : public GGMLRunner { + int in_channels; + int which; + ShiftProbeRunner(ggml_backend_t backend, bool offload, int in_ch, int which) + : GGMLRunner(backend, offload), in_channels(in_ch), which(which) {} + std::string get_desc() override { return "ltx2_vae_shift_probe"; } + sd::Tensor compute(int n_threads, const sd::Tensor& time_embed) { + auto get_g = [&]() -> ggml_cgraph* { + ggml_cgraph* gf = ggml_new_graph(compute_ctx); + ggml_tensor* te = make_input(time_embed); // ne=[4*C, 1] + auto re = ggml_reshape_2d(compute_ctx, te, in_channels, 4); // [C, 4] + auto out = ggml_ext_slice(compute_ctx, re, 1, which, which + 1); // [C, 1] + out = ggml_cont(compute_ctx, out); + ggml_build_forward_expand(gf, out); + return gf; + }; + return take_or_empty(GGMLRunner::compute(get_g, n_threads, true)); + } +}; +#include "model.h" +#include "tensor.hpp" + +namespace { + +sd::Tensor load_raw_bin(const std::string& path, const std::vector& shape) { + sd::Tensor t(shape); + std::ifstream f(path, std::ios::binary); + if (!f.is_open()) { + std::fprintf(stderr, "fatal: cannot open %s\n", path.c_str()); + std::exit(2); + } + f.read(reinterpret_cast(t.data()), + static_cast(t.numel() * sizeof(float))); + if (!f.good()) { + std::fprintf(stderr, "fatal: short read on %s (expected %ld floats)\n", + path.c_str(), t.numel()); + std::exit(2); + } + return t; +} + +struct DiffStats { + float max_abs = 0.f, mean_abs = 0.f, max_rel = 0.f; + int64_t max_abs_idx = -1; +}; + +DiffStats diff_fp32(const float* a, const float* b, int64_t n) { + DiffStats s; + double sum_abs = 0.0; + for (int64_t i = 0; i < n; ++i) { + float abs_err = std::fabs(a[i] - b[i]); + float rel_err = abs_err / (std::fabs(b[i]) + 1e-8f); + if (abs_err > s.max_abs) { s.max_abs = abs_err; s.max_abs_idx = i; } + s.max_rel = std::max(s.max_rel, rel_err); + sum_abs += abs_err; + } + s.mean_abs = static_cast(sum_abs / (n > 0 ? 
n : 1)); + return s; +} + +bool compare(const std::string& tag, const sd::Tensor& got, + const std::string& ref_path, const std::vector& ref_shape, + float tol_max_abs, float tol_mean_abs) { + auto ref = load_raw_bin(ref_path, ref_shape); + if (got.numel() != ref.numel()) { + std::printf(" %-20s SHAPE_FAIL got_numel=%ld ref_numel=%ld\n", + tag.c_str(), got.numel(), ref.numel()); + return false; + } + auto s = diff_fp32(got.data(), ref.data(), got.numel()); + bool pass = s.max_abs < tol_max_abs && s.mean_abs < tol_mean_abs; + std::printf(" %-20s %s max_abs=%.3e mean_abs=%.3e n=%ld\n", + tag.c_str(), pass ? "PASS" : "FAIL", s.max_abs, s.mean_abs, got.numel()); + return pass; +} + +} // namespace + +int main() { + // Enable library logging so load_tensors shape mismatches surface on stderr. + sd_set_log_callback( + [](enum sd_log_level_t /*level*/, const char* text, void* /*data*/) { + std::fputs(text, stderr); + }, + nullptr); + + const std::string ref_dir = "/tmp/vae_ref"; + const std::string state_path = ref_dir + "/state_dict.safetensors"; + + ModelLoader loader; + if (!loader.init_from_file(state_path)) { + std::fprintf(stderr, "fatal: init_from_file failed for %s\n", state_path.c_str()); + return 1; + } + const auto& tsm = loader.get_tensor_storage_map(); + std::printf("[state_dict] loaded %zu tensors from %s\n", tsm.size(), state_path.c_str()); + + // Tiny config from dump_vae.py: in=3, latent=8, patch=2, base_ch=8. + // Encoder: compress_space_res(×2), compress_time_res(×2), res_x(1 layer). + // Decoder: compress_space(m=1), compress_time(m=1), res_x(1 layer, timestep_cond=True). 
+ const int in_ch = 3; + const int latent_ch = 8; + const int base_ch = 8; + const int patch = 2; + const int B = 1, F = 9, H = 16, W_ = 16; + + std::vector enc_specs = { + {LTXVAE::EncoderBlockKind::COMPRESS_SPACE_RES, 1, 2}, + {LTXVAE::EncoderBlockKind::COMPRESS_TIME_RES, 1, 2}, + {LTXVAE::EncoderBlockKind::RES_X, 1, 1}, + }; + std::vector dec_specs = { + {LTXVAE::DecoderBlockKind::COMPRESS_SPACE, 1, 1}, + {LTXVAE::DecoderBlockKind::COMPRESS_TIME, 1, 1}, + {LTXVAE::DecoderBlockKind::RES_X, 1, 1}, + }; + + ggml_backend_t backend = ggml_backend_cpu_init(); + + // --- Encoder --- + LTXVAE::VAEEncoderRunner enc_runner(backend, /*offload=*/false, tsm, + /*prefix=*/"vae.encoder", + in_ch, latent_ch, patch, enc_specs); + enc_runner.alloc_params_buffer(); + std::map enc_params; + enc_runner.get_param_tensors(enc_params, "vae.encoder"); + std::printf("[enc] requesting %zu param tensors\n", enc_params.size()); + if (!loader.load_tensors(enc_params)) { + std::fprintf(stderr, "fatal: encoder load_tensors failed\n"); + return 1; + } + + // Load video input. Python shape (1, 3, 9, 16, 16) → GGML ne=[W=16, H=16, T=9, C=3]. + auto video_in = load_raw_bin(ref_dir + "/tensors/video_in.bin", {W_, H, F, in_ch}); + std::printf("[enc] running encoder (traced)\n"); + + bool pass = true; + struct Stage { int idx; const char* name; std::vector shape; float abs_tol, mean_tol; }; + // Dump order & shapes (PyTorch-majored): + // 0 post_patchify (1,12,9,8,8) → ne=[8,8,9,12] + // 1 post_conv_in (1,8,9,8,8) → ne=[8,8,9,8] + // 2 down_block[0] (cs) (1,16,9,4,4) → ne=[4,4,9,16] + // 3 down_block[1] (ct) (1,32,5,4,4) → ne=[4,4,5,32] + // 4 down_block[2] (res)(1,32,5,4,4) → ne=[4,4,5,32] + // 5 post_norm (1,32,5,4,4) → ne=[4,4,5,32] + // 6 post_conv_out (1,9,5,4,4) → ne=[4,4,5,9] + // 7 means_preNorm (1,8,5,4,4) → ne=[4,4,5,8] + // 8 latent (1,8,5,4,4) → ne=[4,4,5,8] + // Conv3d weights are stored f16 in the block — each conv boundary introduces a + // fp16-quantization step (~1e-3 abs per layer). 
Tolerances are set accordingly. + std::vector stages = { + {0, "enc_post_patchify", {8, 8, F, 12}, 1e-6f, 1e-7f}, // pure rearrange + {1, "enc_post_conv_in", {8, 8, F, 8}, 2e-3f, 3e-4f}, + {2, "enc_block_0", {4, 4, F, 16}, 3e-3f, 5e-4f}, + {3, "enc_block_1", {4, 4, 5, 32}, 5e-3f, 8e-4f}, + {4, "enc_block_2", {4, 4, 5, 32}, 5e-3f, 1e-3f}, + {5, "enc_post_norm", {4, 4, 5, 32}, 5e-3f, 1e-3f}, + {6, "enc_post_conv_out", {4, 4, 5, 9}, 5e-3f, 1e-3f}, + {7, "enc_means_preNorm", {4, 4, 5, 8}, 5e-3f, 1e-3f}, + {8, "latent", {4, 4, 5, 8}, 5e-3f, 1e-3f}, + }; + for (const auto& s : stages) { + auto got = enc_runner.compute(1, video_in, s.idx); + pass &= compare(s.name, got, ref_dir + "/tensors/" + s.name + ".bin", s.shape, + s.abs_tol, s.mean_tol); + } + + // --- Decoder --- + LTXVAE::VAEDecoderRunner dec_runner(backend, /*offload=*/false, tsm, + /*prefix=*/"vae.decoder", + latent_ch, in_ch, patch, base_ch, + /*timestep_cond=*/true, dec_specs); + dec_runner.alloc_params_buffer(); + std::map dec_params; + dec_runner.get_param_tensors(dec_params, "vae.decoder"); + std::printf("[dec] requesting %zu param tensors\n", dec_params.size()); + + // Diagnose any name/shape mismatches. + std::set file_keys; + for (const auto& kv : tsm) file_keys.insert(kv.first); + int missing = 0; + for (const auto& pt : dec_params) { + auto it = file_keys.find(pt.first); + if (it == file_keys.end()) { + if (missing < 10) std::printf("[dec] missing: %s\n", pt.first.c_str()); + missing++; + } + } + std::printf("[dec] %d / %zu tensors missing from file\n", missing, dec_params.size()); + + if (!loader.load_tensors(dec_params)) { + std::fprintf(stderr, "fatal: decoder load_tensors failed\n"); + return 1; + } + + // Feed the Python reference latent to the decoder so its diffs are independent of + // encoder errors. Once encoder parity is green we can chain them. 
+ auto latent_ref = load_raw_bin(ref_dir + "/tensors/latent.bin", {4, 4, 5, latent_ch}); + sd::Tensor timestep_t({1}); + timestep_t.data()[0] = 0.05f; + // TimestepEmbedder micro-probe: bypass the full decoder and run just the + // up_blocks[0].time_embedder on timestep=0.05 to verify the sinusoidal + linear path. + { + TERunner te_runner(backend, false, tsm, "vae.decoder.up_blocks.0.time_embedder", 256); + te_runner.alloc_params_buffer(); + std::map te_params; + te_runner.get_param_tensors(te_params, "vae.decoder.up_blocks.0.time_embedder"); + if (!loader.load_tensors(te_params)) { + std::fprintf(stderr, "fatal: TE load failed\n"); + return 1; + } + auto te_out = te_runner.compute(1, timestep_t); + // Python dumps shape [B=1, 256] → innermost 256. sd::Tensor stores innermost-first, + // so shape is {256, 1}. Same numel. + pass &= compare("TimestepEmbedder", te_out, ref_dir + "/tensors/te_probe_up0.bin", + {256, 1}, 1e-4f, 1e-5f); + + // Verify the ada-values reshape+slice: Python does `te.reshape(B,4,-1,1,1,1)` → + // unbind(dim=1). The four unbound slices should be te[0:64], te[64:128], te[128:192], + // te[192:256]. Run each slice through the C++ reshape+slice path and byte-compare. 
+ auto te_ref = load_raw_bin(ref_dir + "/tensors/te_probe_up0.bin", {256, 1}); + const char* which_names[] = {"shift1", "scale1", "shift2", "scale2"}; + for (int w = 0; w < 4; w++) { + ShiftProbeRunner sp(backend, false, /*in_ch=*/64, w); + auto slice = sp.compute(1, te_ref); + float maxd = 0.f; + for (int i = 0; i < 64; i++) { + float d = std::fabs(slice.data()[i] - te_ref.data()[w * 64 + i]); + if (d > maxd) maxd = d; + } + std::printf(" shift-probe %-7s max_abs vs te[%d:%d]=%.3e\n", + which_names[w], w * 64, (w + 1) * 64, maxd); + pass &= (maxd < 1e-6f); + } + } + + // Per-stage trace now includes intermediates pushed INSIDE the first res_x block: + // 0 post_unnorm, 1 post_conv_in, 2 time_embed, 3 post_norm1, 4 shift1, 5 scale1, + // 6 post_adaln1, 7 post_conv1, 8 post_norm2, 9 up_block[0] out, ... + auto got_norm1 = dec_runner.compute(1, latent_ref, timestep_t, 3); + pass &= compare("resblock0 post_norm1", got_norm1, + ref_dir + "/tensors/dec_resblock0_post_norm1.bin", + {4, 4, 5, 64}, 2e-3f, 5e-4f); + auto got_adaln1 = dec_runner.compute(1, latent_ref, timestep_t, 6); + pass &= compare("resblock0 post_adaln1", got_adaln1, + ref_dir + "/tensors/dec_resblock0_post_adaln1.bin", + {4, 4, 5, 64}, 5e-3f, 1e-3f); + auto got_conv1 = dec_runner.compute(1, latent_ref, timestep_t, 7); + pass &= compare("resblock0 post_conv1", got_conv1, + ref_dir + "/tensors/dec_resblock0_post_conv1.bin", + {4, 4, 5, 64}, 5e-3f, 1e-3f); + auto got_norm2 = dec_runner.compute(1, latent_ref, timestep_t, 8); + pass &= compare("resblock0 post_norm2", got_norm2, + ref_dir + "/tensors/dec_resblock0_post_norm2.bin", + {4, 4, 5, 64}, 1e-2f, 2e-3f); + + // After causal=false + reflect-padding fixes, trace indices in the decoder have shifted. 
+ // New layout: + // 0 post_unnorm 1 post_conv_in 2 time_embed 3 post_norm1 4 shift1 + // 5 scale1 6 post_adaln1 7 post_conv1 8 post_norm2 9 up_block[0] out + // 10 up_block[1] 11 up_block[2] 12 post_pixel_norm 13 post_ada + // 14 post_conv_out 15 video_out + struct Stage2 { int idx; const char* name; std::vector shape; float atol, mtol; }; + std::vector stages2 = { + // Shapes in ne-order (W, H, T, C) after each decoder block. + // Compress_time expands T 5→9; compress_space expands spatial 4→8 (patch=2 still + // to apply at the very end via unpatchify). + { 9, "dec_block_0", {4, 4, 5, 64}, 1e-2f, 2e-3f}, + {10, "dec_block_1", {4, 4, F, 64}, 1e-2f, 2e-3f}, + {11, "dec_block_2", {8, 8, F, 64}, 2e-2f, 4e-3f}, + {12, "dec_post_pixel_norm", {8, 8, F, 64}, 2e-2f, 4e-3f}, + {13, "dec_post_ada", {8, 8, F, 64}, 2e-2f, 4e-3f}, + {14, "dec_post_conv_out", {8, 8, F, 12}, 2e-2f, 4e-3f}, + }; + for (const auto& s : stages2) { + auto got = dec_runner.compute(1, latent_ref, timestep_t, s.idx); + pass &= compare(s.name, got, + ref_dir + "/tensors/" + std::string(s.name) + ".bin", + s.shape, s.atol, s.mtol); + } + + auto decoded = dec_runner.compute(1, latent_ref, timestep_t); + pass &= compare("dec video", decoded, ref_dir + "/tensors/video_out.bin", {W_, H, F, in_ch}, 1e-2f, 2e-3f); + + // Per-stage probe on the same runner (since GGMLRunner can be reused across + // multiple computes, as the encoder path does 9 times without issue). 
+ if (!pass) { + std::printf("\n[dec] per-stage probe:\n"); + const char* stage_names[] = { + "dec_post_unnorm", "dec_post_conv_in", "dec_block_0", "dec_block_1", "dec_block_2", + "dec_post_pixel_norm", "dec_post_ada", "dec_post_conv_out", "video_out" + }; + for (int idx = 0; idx < 9; idx++) { + std::printf(" [%d] stage=%s computing...\n", idx, stage_names[idx]); std::fflush(stdout); + auto out = dec_runner.compute(1, latent_ref, timestep_t, idx); + std::string tag = stage_names[idx]; + std::string ref_path = ref_dir + "/tensors/" + tag + ".bin"; + std::ifstream check(ref_path); + if (check.good()) { + check.close(); + std::vector shape = {out.shape()[0], out.shape()[1], out.shape()[2], out.shape()[3]}; + auto ref = load_raw_bin(ref_path, shape); + if (ref.numel() != out.numel()) { + std::printf(" [%d] %-20s SHAPE_MISMATCH got=%ld ref=%ld (shape=%ld,%ld,%ld,%ld)\n", + idx, tag.c_str(), out.numel(), ref.numel(), + shape[0], shape[1], shape[2], shape[3]); + continue; + } + auto s = diff_fp32(out.data(), ref.data(), out.numel()); + std::printf(" [%d] %-20s max_abs=%.3e mean_abs=%.3e\n", + idx, tag.c_str(), s.max_abs, s.mean_abs); + } else { + float m0 = 0.f, m1 = 0.f; + for (int64_t i = 0; i < out.numel(); i++) { float a = std::fabs(out.data()[i]); m0 = std::max(m0, a); m1 += a; } + m1 /= out.numel() > 0 ? out.numel() : 1; + std::printf(" [%d] %-20s (no ref) shape=[%ld,%ld,%ld,%ld] max_abs=%.3f mean_abs=%.3f\n", + idx, tag.c_str(), + out.shape()[0], out.shape()[1], out.shape()[2], out.shape()[3], m0, m1); + } + } + } + + std::printf("\n%s\n", pass ? "ALL VAE PARITY: PASS" : "ALL VAE PARITY: FAIL"); + (void)B; + return pass ? 0 : 3; +}