diff --git a/docs/ltxv.md b/docs/ltxv.md new file mode 100644 index 000000000..b933c04c3 --- /dev/null +++ b/docs/ltxv.md @@ -0,0 +1,206 @@ +# LTX-Video 2.3 support — conditional text-to-video works end-to-end + +Branch: `feat/ltx-video` in +. Ports Lightricks' LTX-2.3 +22B audio-video foundation model (`Lightricks/LTX-2.3`) to +stable-diffusion.cpp, video-only path. **Text conditioning wired via a +native Gemma-3-12B port** so prompts actually steer the output. + +## Status — prompts generate the thing you asked for + +Validated on an NVIDIA GB10 (Grace Blackwell, CUDA 13, 119 GB unified memory) +with `ltx-2.3-22b-distilled.safetensors` (46 GB BF16) + Gemma-3-12B-it +(24 GB BF16) as text encoder: + +| Stage | Result | +|---|---| +| LTX version detection (`model.cpp`) | `VERSION_LTXV2` detected on `audio_scale_shift_table` / `audio_patchify_proj` / `audio_adaln_single` / `av_ca_video_scale_shift_adaln_single` / `video_embeddings_connector` | +| Weight registration | 4444 transformer + 170 VAE + 4 text_embedding_projection tensors registered — **zero missing, zero shape mismatches** vs. the 22B checkpoint | +| Checkpoint load | 46 GB BF16 loads in ~9 s; audio_vae / vocoder ignored (video-only pipeline) | +| Gemma-3-12B text encoder | Loads + runs in 5 s on GB10; 49-layer hidden states match HuggingFace to bf16 precision; `text_embedding_projection.video_aggregate_embed` output: std=6.828 (HF: 6.830) | +| Transformer forward | 48 layers × 32 heads × 128 head-dim (inner_dim 4096), 8 distilled steps in 123 s on GB10 | +| VAE decode | 9-block decoder with per-channel RMS norm + proper 3-D depth-to-space; 16-frame latent → 121-frame video in 16 s | +| End-to-end | 704×480×9 WebP in ~14 s; 768×512×121 WebP in ~140 s on GB10; **prompts generate the described subject** (cat → cat, dragon → dragon, etc.) | +| Quantization | BF16 46 GB → q8_0 28.3 GB via `sd-cli -M convert --type q8_0` in 9.6 s; q8_0 GGUF runs end-to-end | + +## What's in the code + +**Transformer (`src/ltxv.hpp`)** +- `LTX2VideoTransformer3DModel` — 48 layers; inner 4096 (32×128), cross-attn dim 4096, caption 4096 +- `LTXAttention` — qk_norm_across_heads, always-on gated attention (`to_gate_logits` + 2·σ), interleaved and split RoPE variants +- `LTX2VideoTransformerBlock` — per-block `scale_shift_table` (9, dim), `prompt_scale_shift_table` (2, dim), `scale_shift_table_a2v_ca_video/audio` (5, dim/audio_dim), `audio_scale_shift_table` (9, audio_dim), `audio_prompt_scale_shift_table` (2, audio_dim). Forward path runs **only** video self-attn + prompt cross-attn + FF; audio self-attn, a2v/v2a cross-attn and audio FFN are loaded but skipped (isolate_modalities=True). 
+- `AdaLayerNormSingle` with configurable `num_mod_params` +- `EmbeddingsConnector` — 128 learnable registers + 8 transformer_1d_blocks (gated self-attn + FF) for both video and audio +- Split 3-D RoPE (video-axis F/H/W, dim/6 freqs per axis, vae_scale_factors (8, 32, 32), `causal_offset=1`, fps scaling, pair-swap rotation) +- Stub `LTXV2Conditioner` returning zero embeddings of shape `[1, 128, 4096]` + +**VAE (`src/ltxv.hpp`)** +- 9-block encoder: res×4 @128, spatial↓(1,2,2) 128→256, res×6 @256, temporal↓(2,1,1) 256→512, res×4 @512, st↓(2,2,2) 512→1024, res×2 @1024, st↓(2,2,2) 1024→1024, res×2 @1024 +- Decoder is the exact mirror +- `VAEResBlock` is the LTX-2.3 simplified shape (two `CausalConv3d` with silu gates, no norms, no timestep modulation) +- `CausalConv3d` uses `conv.weight` / `conv.bias` names, hardcoded F16 dtype so it stays within the CUDA `ggml_cuda_op_im2col_3d` accepted types +- `VAEUpsampler` pixel-shuffle drops the first `st_t − 1` frames after each temporal upsample so `f_out = (f_in − 1) × st_t + 1` composes across all upsamples + +**Pipeline wiring (`src/stable-diffusion.cpp` etc.)** +- `VERSION_LTXV2` / `sd_version_is_ltxv2` / `sd_version_is_dit` entry +- VAE factory arm builds `LTXV::LTXVVAERunner` +- FLOW_PRED with `default_flow_shift = 3.0` +- Latent channels 128, VAE scale factor 32, temporal compression 8 +- Frame count padded to 8k+1 (LTX-2.3 I/O spec) +- Ignore prefixes: `audio_vae.`, `vocoder.`, `text_embedding_projection.` + +## Numerical correctness — resolved + +Nine bugs were diagnosed and fixed by working backwards from the VAE output +(and later the text-conditioning path) using graph-level probes. Each one is +noted here because the same mistake is easy to make again porting future +video VAE/DiT stacks: + +1. **EmbeddingsConnector pre-norm.** Reference + `_BasicTransformerBlock1D.forward` does `rms_norm(hidden_states)` before + both attn1 and ff (and a final `rms_norm` after the stack). We had + bare `x = x + attn(x); x = x + ff(x)` — residuals compounded across 8 + blocks and drove the connector output to std≈1e12, exploding cross-attn + in every transformer block. + +2. **Final `norm_out` before the scale/shift + `proj_out`.** Reference + `LTXModel._process_output` is + `x = norm_out(x); x = x * (1 + scale) + shift; x = proj_out(x)`. + Without the LayerNorm the post-block activation (std≈285 after 48 + layers) leaked into the predicted velocity and the sampler diverged. + Transformer output std went from 57 → 1.0 after adding `ggml_norm`. + +3. **VAE `conv_norm_out` + SiLU before `conv_out`.** The reference decoder + ends with `sample = conv_norm_out(sample); sample = silu(sample); + sample = conv_out(sample)`. We were skipping the PixelNorm+SiLU, so + output pixels were O(1000) instead of O(1). + +4. **Latent per-channel normalisation.** `vae.per_channel_statistics.*` + is now materialised to CPU and applied in `diffusion_to_vae_latents` + (`x * std + mean`) / `vae_to_diffusion_latents` (`(x - mean) / std`). + +5. **VAE depth-to-space ordering.** `ggml_reshape_4d` alone doesn't + implement einops `b (c p1 p2 p3) f h w -> b c (f p1) (h p2) (w p3)` — + the sub-indices come out in the wrong order. Replaced with a proper + `depth_to_space_3d` helper that decomposes the channel axis through + permute+cont passes so p3 lands inner-of-W, p2 inner-of-H, p1 + inner-of-F. Eliminated the visible banding. + +6. 
**Gemma-3 49-layer concat layout.** `ggml_concat(hidden_all[i], + axis=0)` produces a flat axis with layer-slow / hidden-fast ordering, + but HF's `reshape(B, T, D*L)` produces hidden-slow / layer-fast. + `text_embedding_projection.video_aggregate_embed` was trained for the + HF layout — a transposed input made the projection output essentially + noise and all prompts generated the same scene. Fixed by stacking + along axis 2 → permute(2, 0, 1, 3) → reshape to [D*L, T, 1]. + +7. **EmbeddingsConnector register layout.** Reference + `_replace_padded_with_learnable_registers` produces a **fixed + 128-token** output with real text at positions [0..L-1] and + `learnable_registers[L..127]` at [L..127]. We were concatenating + registers+text to 128+L tokens in the wrong order. Rewrote the + connector's register path. + +8. **Double attention scaling in Gemma-3.** Gemma-3 uses + `scale = 1/sqrt(query_pre_attn_scalar) = 1/sqrt(head_dim)` for the + 12B variant — and `ggml_ext_attention_ext` applies the same + `1/sqrt(d_head)` internally. Applying both multiplied the softmax + temperature by 1/16, collapsing attention to near-uniform and + producing a persistent ~sqrt(D) "attention sink" outlier at the same + hidden dim for every layer. Dropping the explicit Q scale made the + Gemma forward match HF to bf16 precision. + +9. **Two different patchify conventions in `ops.py` vs `sampling.py`.** + `DepthToSpaceUpsample` (intermediate upsamplers) uses + `b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)` — p3 (w-stride) + innermost in the channel axis. `ops.py::unpatchify` (the decoder's + final 4×4 un-patch) uses + `b (c p r q) f h w -> b c (f p) (h q) (w r)` — q (h_patch) innermost. + We were reusing the upsampler helper for the final unpatchify, which + silently transposed every 4×4 output block and left a visible fine- + scale hatching artefact that survived every diffusion step. Added a + dedicated `depth_to_space_3d_patch` that swaps the inner (p_w, p_h) + pair of the channel axis before delegating, matching the reference + layout exactly. + +Cross-checked against the 22B checkpoint's embedded config +(`safetensors __metadata__["config"]["vae"]`): `norm_layer=pixel_norm`, +`spatial_padding_mode=zeros`, `timestep_conditioning=false`, +`causal_decoder=false`, patch_size=4, and none of the `compress_all` +decoder blocks sets `residual=True` — so the residual skip from +`DepthToSpaceUpsample` is correctly absent here. + +End-to-end result: prompts now actually generate the described content. +Seed 42 with *"a cat walking across a grassy field"* produces exactly +that. Per-layer Gemma hidden states match HF to bf16 noise; the +projected cross-attention features match HF (min/max/std 0.0%/0.2%/0.03% +different). + +## Remaining items (future sessions) + +1. **Audio branch.** Roughly half of the LTX transformer buffer is + audio-related (`audio_attn1/2`, `audio_to_video_attn`, + `video_to_audio_attn`, `audio_embeddings_connector`, + `audio_scale_shift_table`, etc.). Adding joint audio+video generation + also needs the `audio_vae` (102 tensors), the HiFi-GAN-style + `vocoder` (1227 tensors), and the BWE upsampler. Non-trivial. + +2. **Schedule for non-distilled variants.** The 22B non-distilled model + uses LTX2Scheduler (token-count-dependent shift, stretched to a + terminal value). Only the distilled 8-step table is wired up today. + +3. **Quantised Gemma.** Gemma-3-12B is 24 GB in BF16. A q8_0 or q4_k + conversion would drop it to ~12 GB / ~7 GB — useful for smaller + hardware. 
The existing sd-cli `-M convert` path should handle it. + +## How to run the e2e test + +First, grab the two model artefacts: + +```bash +# LTX-2.3 distilled 22B (46 GB BF16 safetensors): +hf download Lightricks/LTX-2.3 ltx-2.3-22b-distilled.safetensors \ + --local-dir ltxv-models + +# Gemma-3-12B-it (tokenizer.model + 5x safetensors shards, ~24 GB BF16): +hf download google/gemma-3-12b-it --local-dir gemma-3-12b-it +``` + +Then run with the distilled 8-step schedule (auto-selected when +`--steps 8` is passed on an ltxv2 model): + +```bash +./sd-cli -M vid_gen \ + -m ltxv-models/ltx-2.3-22b-distilled.safetensors \ + --text-encoder gemma-3-12b-it \ + -p "a cat walking across a grassy field" \ + -W 704 -H 480 --video-frames 9 \ + --steps 8 --cfg-scale 1 \ + -o /tmp/ltx23.webp --seed 42 + +# Official distilled shape (768x512, 121 frames, ~140 s on GB10): +./sd-cli -M vid_gen \ + -m ltxv-models/ltx-2.3-22b-distilled.safetensors \ + --text-encoder gemma-3-12b-it \ + -p "a cat walking across a grassy field" \ + -W 768 -H 512 --video-frames 121 \ + --steps 8 --cfg-scale 1 \ + -o /tmp/ltx23.webp --seed 42 + +# Without --text-encoder: LTX runs unconditionally (zero embeddings), +# pipeline still produces valid frames but ignores the prompt. + +# Quantise the LTX DiT to q8_0 GGUF (46 GB -> 28 GB): +./sd-cli -M convert \ + -m ltxv-models/ltx-2.3-22b-distilled.safetensors \ + -o ltxv-models/ltx-2.3-22b-distilled-q8_0.gguf \ + --type q8_0 +``` + +## References + +- LTX-2.3 model card: https://huggingface.co/Lightricks/LTX-2.3 +- Diffusers LTX-2.0 reference (not an exact match for 2.3): + https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_ltx2.py +- Upstream ltx-pipelines (Lightricks): + https://github.com/Lightricks/LTX-2/tree/main/packages/ltx-pipelines diff --git a/docs/test_ltxv.sh b/docs/test_ltxv.sh new file mode 100644 index 000000000..abaa97f41 --- /dev/null +++ b/docs/test_ltxv.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# End-to-end LTX-2.3 test script for DGX. 
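+# Paths below (SD_CLI, MODEL) assume the CUDA build layout and the model
+# download steps from docs/ltxv.md; adjust them when running on another machine.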
+# Run as: ssh dgx.casa 'bash -s' < /tmp/ltxv_test.sh + +set -e +set -o pipefail + +SD_CLI=~/ltxv-sd-cpp/build-cuda/bin/sd-cli +MODEL=~/ltxv-models/ltx-2.3-22b-distilled.safetensors +OUT=/tmp/ltx23_out + +mkdir -p "$OUT" +echo "==============================================" +echo "[1/3] vid_gen BF16 (no quant) — dry run" +echo "==============================================" +$SD_CLI -M vid_gen \ + -m "$MODEL" \ + -p "a cat walking across a grassy field" \ + -W 704 -H 480 --video-frames 9 \ + --steps 1 --cfg-scale 1 \ + -o "$OUT/dryrun.webp" \ + --seed 42 \ + -v 2>&1 | tail -80 + +echo "" +echo "==============================================" +echo "[2/3] Quantize to q8_0" +echo "==============================================" +$SD_CLI -M convert \ + -m "$MODEL" \ + -o "$OUT/ltx23_q8_0.gguf" \ + --type q8_0 \ + -v 2>&1 | tail -30 + +echo "" +echo "==============================================" +echo "[3/3] vid_gen with q8_0 GGUF" +echo "==============================================" +$SD_CLI -M vid_gen \ + -m "$OUT/ltx23_q8_0.gguf" \ + -p "a cat walking across a grassy field" \ + -W 704 -H 480 --video-frames 9 \ + --steps 4 --cfg-scale 1 \ + -o "$OUT/q8_output.webp" \ + --seed 42 \ + -v 2>&1 | tail -80 + +echo "" +echo "==============================================" +echo "Outputs in $OUT:" +ls -la "$OUT/" diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 2dcd1d53a..cffd8dccf 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,4 +1,5 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) add_subdirectory(cli) -add_subdirectory(server) \ No newline at end of file +add_subdirectory(server) +add_subdirectory(gemma_test) \ No newline at end of file diff --git a/examples/common/common.cpp b/examples/common/common.cpp index 2d29df267..48e7feba9 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -332,6 +332,12 @@ ArgOptions SDContextParams::get_options() { "--qwen2vl_vision", "alias of --llm_vision. Deprecated.", &llm_vision_path}, + {"", + "--text-encoder", + "path to the text encoder directory (e.g. google/gemma-3-12b-it for LTX-2.3). " + "Must contain tokenizer.model plus *.safetensors shards. 
" + "When unset, LTX-2.3 runs unconditionally.", + &text_encoder_path}, {"", "--diffusion-model", "path to the standalone diffusion model", @@ -744,6 +750,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f chroma_use_t5_mask, chroma_t5_mask_pad, qwen_image_zero_cond_t, + text_encoder_path.c_str(), }; return sd_ctx_params; } diff --git a/examples/common/common.h b/examples/common/common.h index 333d33116..8c78c95ed 100644 --- a/examples/common/common.h +++ b/examples/common/common.h @@ -90,6 +90,7 @@ struct SDContextParams { std::string t5xxl_path; std::string llm_path; std::string llm_vision_path; + std::string text_encoder_path; // LTX-2.3 Gemma-3 dir std::string diffusion_model_path; std::string high_noise_diffusion_model_path; std::string vae_path; diff --git a/examples/gemma_test/CMakeLists.txt b/examples/gemma_test/CMakeLists.txt new file mode 100644 index 000000000..b324b95a4 --- /dev/null +++ b/examples/gemma_test/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET gemma3-test) +add_executable(${TARGET} gemma3_test.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17) diff --git a/examples/gemma_test/gemma3_test.cpp b/examples/gemma_test/gemma3_test.cpp new file mode 100644 index 000000000..8ca89fd19 --- /dev/null +++ b/examples/gemma_test/gemma3_test.cpp @@ -0,0 +1,181 @@ +// Gemma-3-12B numerical validation. +// +// Loads Gemma-3-12B from HF safetensors (or GGUF), tokenises a prompt, +// runs the text-only transformer forward on a CUDA backend, and prints +// per-layer hidden-state statistics so we can diff against a HuggingFace +// reference. +// +// Usage: +// gemma3-test "" +// +// The first argument can be either a single safetensors path or a +// directory containing shard files (model-00001-of-00005.safetensors, +// etc.) — we glob "*.safetensors" in that directory. + +#include +#include +#include +#include +#include +#include + +#include "../../src/gemma3.hpp" +#include "../../src/ggml_extend.hpp" +#include "../../src/model.h" +#include "../../src/tokenizers/gemma3_tokenizer.h" + +static bool path_is_directory(const std::string& p) { + struct stat st; + if (stat(p.c_str(), &st) != 0) return false; + return S_ISDIR(st.st_mode); +} + +static std::vector list_safetensors(const std::string& dir) { + std::vector out; + DIR* d = opendir(dir.c_str()); + if (!d) return out; + struct dirent* e; + while ((e = readdir(d)) != nullptr) { + std::string name = e->d_name; + if (name.size() > 12 && name.substr(name.size() - 12) == ".safetensors") { + out.push_back(dir + "/" + name); + } + } + closedir(d); + std::sort(out.begin(), out.end()); + return out; +} + +static void log_stats(const char* label, const sd::Tensor& t) { + if (t.empty()) { + std::fprintf(stderr, "[stats] %s: EMPTY\n", label); + return; + } + int64_t n = t.numel(); + const float* d = t.data(); + double mn = 1e30, mx = -1e30, sum = 0, sq = 0; + size_t nan = 0; + for (int64_t i = 0; i < n; ++i) { + double v = d[i]; + if (std::isnan(v)) { nan++; continue; } + if (v < mn) mn = v; + if (v > mx) mx = v; + sum += v; + sq += v * v; + } + double mean = (n - nan) ? sum / (n - nan) : 0; + double var = (n - nan) ? (sq / (n - nan)) - mean * mean : 0; + double stdv = var > 0 ? 
std::sqrt(var) : 0; + std::string shape; + for (size_t i = 0; i < t.shape().size(); ++i) { + if (i) shape += "x"; + shape += std::to_string(t.shape()[i]); + } + std::fprintf(stderr, + "[stats] %-22s shape=[%s] min=%+.4f max=%+.4f mean=%+.4f std=%.4f nan=%zu\n", + label, shape.c_str(), mn, mx, mean, stdv, nan); +} + +int main(int argc, char** argv) { + if (argc < 4) { + std::fprintf(stderr, + "usage: %s \"\"\n", + argv[0]); + return 1; + } + std::string model_path = argv[1]; + std::string tok_path = argv[2]; + std::string prompt = argv[3]; + + // Tokenise. + Gemma3Tokenizer tok; + std::string terr; + if (!tok.load_from_spm(tok_path, &terr)) { + std::fprintf(stderr, "tokenizer load failed: %s\n", terr.c_str()); + return 2; + } + auto ids = tok.encode(prompt, /*add_bos=*/true, /*add_eos=*/false); + std::fprintf(stderr, "[tok] encoded %zu tokens: [", ids.size()); + for (size_t i = 0; i < ids.size() && i < 32; ++i) { + std::fprintf(stderr, "%s%d", i ? "," : "", ids[i]); + } + if (ids.size() > 32) std::fprintf(stderr, ",..."); + std::fprintf(stderr, "]\n"); + + // Build sd::Tensor input_ids [L, N=1]. + sd::Tensor input_ids(std::vector{(int64_t)ids.size(), 1}); + std::memcpy(input_ids.data(), ids.data(), ids.size() * sizeof(int32_t)); + + // Model loader: accept a single file or a directory of shards. + ModelLoader loader; + std::vector files; + if (path_is_directory(model_path)) { + files = list_safetensors(model_path); + } else { + files.push_back(model_path); + } + if (files.empty()) { + std::fprintf(stderr, "no safetensors files at %s\n", model_path.c_str()); + return 3; + } + for (const auto& f : files) { + std::fprintf(stderr, "[load] %s\n", f.c_str()); + if (!loader.init_from_file(f, /*prefix=*/"language_model.")) { + std::fprintf(stderr, "init_from_file failed: %s\n", f.c_str()); + return 4; + } + } + + // Backend. +#ifdef SD_USE_CUDA + ggml_backend_t backend = ggml_backend_cuda_init(0); + if (!backend) { + std::fprintf(stderr, "CUDA init failed; falling back to CPU\n"); + } +#else + ggml_backend_t backend = nullptr; +#endif + if (!backend) backend = ggml_backend_cpu_init(); + std::fprintf(stderr, "[be] %s\n", ggml_backend_name(backend)); + + // Build runner. The tensor map sees language_model.model.* keys from + // HF — our Runner prefix is "model." so together they resolve to the + // expected names (language_model.model.embed_tokens.weight etc). + GEMMA3::Gemma3Runner runner(backend, /*offload=*/false, + loader.get_tensor_storage_map(), + /*prefix=*/"model"); + if (!runner.alloc_params_buffer()) { + std::fprintf(stderr, "alloc_params_buffer failed\n"); + return 5; + } + + std::map tensors; + runner.get_param_tensors(tensors, /*prefix=*/"language_model.model"); + std::fprintf(stderr, "[load] mapping %zu tensors\n", tensors.size()); + if (!loader.load_tensors(tensors, /*ignore=*/{}, /*n_threads=*/4)) { + std::fprintf(stderr, "load_tensors failed\n"); + return 6; + } + + // Probe every 4th layer so we can diff against HF's hidden_states. + // Also dump the (tok=1, d=2339) value — known HF outlier position — + // to verify element-level agreement. + for (int probe : {0, 1, 2, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 47, 48}) { + auto h = runner.compute_layer_hidden(4, input_ids, probe); + char label[64]; + std::snprintf(label, sizeof(label), "layer[%d]", probe); + log_stats(label, h); + // h shape is [hidden=3840, T, 1]. Tok 1 is at offset hidden per + // the usual ggml layout (ne[0] fast). Index (d=2339, tok=1). 
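+        // For an f32 tensor of shape [D, T], the flat offset of element (d, tok)
+        // is tok * D + d; hence idx = 1 * D + 2339 below.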
+ if (h.shape().size() >= 2 && h.shape()[0] > 2339 && h.shape()[1] > 1) { + int64_t idx = 1 * h.shape()[0] + 2339; + std::fprintf(stderr, " tok=1 d=2339: %.4f\n", h.data()[idx]); + } + } + + // Also dump the full 49-layer concatenated feature — what LTX's + // text_embedding_projection.video_aggregate_embed will consume. + auto cat = runner.compute_concatenated_hiddens(4, input_ids); + log_stats("concat[49*H]", cat); + return 0; +} diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 75027f8f8..6940d44b0 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -203,6 +203,10 @@ typedef struct { bool chroma_use_t5_mask; int chroma_t5_mask_pad; bool qwen_image_zero_cond_t; + // For LTX-2.3: directory containing Gemma-3-12B-it safetensors shards + // + tokenizer.model. When unset, LTXV2Conditioner returns zero + // embeddings (unconditional generation). + const char* text_encoder_path; } sd_ctx_params_t; typedef struct { diff --git a/src/conditioner.hpp b/src/conditioner.hpp index 9f4d45524..da861f3d7 100644 --- a/src/conditioner.hpp +++ b/src/conditioner.hpp @@ -4,9 +4,11 @@ #include #include "clip.hpp" +#include "gemma3.hpp" #include "llm.hpp" #include "t5.hpp" #include "tensor_ggml.hpp" +#include "tokenizers/gemma3_tokenizer.h" struct SDCondition { sd::Tensor c_crossattn; @@ -96,6 +98,181 @@ struct Conditioner { } }; +// Small block that owns the LTX-2.3 `text_embedding_projection` linear +// (`video_aggregate_embed`, loaded from the LTX 22B safetensors) and +// applies it to a 188160-dim feature produced by Gemma3Runner. +struct LTXTextEmbedProjection : public GGMLRunner { + int64_t in_features; + int64_t out_features; + std::shared_ptr video_proj; + + LTXTextEmbedProjection(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + int64_t in_features, + int64_t out_features, + const std::string& prefix = "text_embedding_projection.video_aggregate_embed") + : GGMLRunner(backend, offload_params_to_cpu), + in_features(in_features), + out_features(out_features) { + // Force F32: the [4096 x 188160] matmul accumulates over 188k + // terms, and BF16 mantissa (~3 decimal digits) drops enough + // precision to visibly shrink the output range (std 5.2 vs HF's + // 6.8 on identical inputs). F32 brings us back within HF's noise. + video_proj = std::make_shared(in_features, out_features, /*bias=*/true, + /*force_f32=*/true, + /*force_prec_f32=*/true); + video_proj->init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { return "ltx_text_proj"; } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + video_proj->get_param_tensors(tensors, prefix); + } + + // x: [in_features, L, 1] + // returns: [out_features, L, 1] + sd::Tensor compute(int n_threads, const sd::Tensor& x) { + auto get_graph = [&]() -> ggml_cgraph* { + auto* gf = ggml_new_graph(compute_ctx); + auto xt = make_input(x); + auto rctx = get_context(); + auto y = video_proj->forward(&rctx, xt); + y = ggml_cont(compute_ctx, y); + ggml_build_forward_expand(gf, y); + return gf; + }; + auto result = GGMLRunner::compute(get_graph, n_threads, false); + if (!result.has_value()) return {}; + return std::move(*result); + } +}; + +// LTX-2.3 conditioner. 
+// +// When a Gemma-3-12B runner and tokenizer are provided, this path runs +// the full HF reference: +// prompt -> tokenize (Gemma SPM) -> Gemma-3 forward + 49-layer concat +// -> per-token RMSNorm + sqrt(out/D) rescale +// -> text_embedding_projection.video_aggregate_embed (Linear 188160->4096) +// -> c_crossattn [out=4096, L, 1] +// +// When no Gemma runner is provided (e.g. running for shape validation), +// we emit zero embeddings of the expected shape so the DiT can still run. +struct LTXV2Conditioner : public Conditioner { + int64_t caption_channels; + int64_t max_tokens; + std::shared_ptr gemma_runner; + std::shared_ptr gemma_tokenizer; + std::shared_ptr video_proj; + bool flash_attn_enabled = false; + + LTXV2Conditioner(int64_t caption_channels = 4096, int64_t max_tokens = 128) + : caption_channels(caption_channels), max_tokens(max_tokens) {} + + void attach_gemma(std::shared_ptr runner, + std::shared_ptr tokenizer, + std::shared_ptr proj) { + gemma_runner = std::move(runner); + gemma_tokenizer = std::move(tokenizer); + video_proj = std::move(proj); + } + + void alloc_params_buffer() override {} + void free_params_buffer() override {} + void get_param_tensors(std::map& tensors) override {} + size_t get_params_buffer_size() override { return 0; } + void set_flash_attention_enabled(bool enabled) override { + flash_attn_enabled = enabled; + if (gemma_runner) gemma_runner->set_flash_attention_enabled(enabled); + if (video_proj) video_proj->set_flash_attention_enabled(enabled); + } + + SDCondition get_learned_condition(int n_threads, + const ConditionerParams& conditioner_params) override { + SDCondition cond; + if (!gemma_runner || !gemma_tokenizer || !video_proj) { + // Fallback: zero embeddings (pipeline still runs for shape + // validation and for tests without a text encoder). + cond.c_crossattn = sd::zeros({caption_channels, max_tokens, 1}); + return cond; + } + + // Tokenize. Gemma base convention: BOS prepended, no EOS. + auto ids = gemma_tokenizer->encode(conditioner_params.text, + /*add_bos=*/true, /*add_eos=*/false); + // Truncate to max_tokens to keep the graph bounded. + if ((int64_t)ids.size() > max_tokens) { + ids.resize(max_tokens); + } + sd::Tensor input_ids(std::vector{(int64_t)ids.size(), 1}); + std::memcpy(input_ids.data(), ids.data(), ids.size() * sizeof(int32_t)); + + // Gemma → 188160-dim rescaled concat. + auto concat = gemma_runner->compute_concatenated_hiddens(n_threads, input_ids, + /*target_out_dim=*/caption_channels); + if (concat.empty()) { + LOG_WARN("Gemma forward failed — falling back to zero embeddings"); + cond.c_crossattn = sd::zeros({caption_channels, max_tokens, 1}); + return cond; + } + { + double mn = 1e30, mx = -1e30, sum = 0, sq = 0; + for (int64_t i = 0; i < concat.numel(); ++i) { + double v = concat.data()[i]; + if (v < mn) mn = v; + if (v > mx) mx = v; + sum += v; sq += v * v; + } + double mean = sum / concat.numel(); + double std = std::sqrt(std::max(0.0, sq / concat.numel() - mean * mean)); + LOG_INFO("[ltxv.cond] gemma_concat: shape=[%zu,%zu] min=%.3f max=%.3f mean=%.3f std=%.3f", + (size_t)concat.shape()[0], (size_t)concat.shape()[1], mn, mx, mean, std); + // Dump first/last 10 values of token 0 to /tmp for diff vs HF. 
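+            // (The whole concat tensor is written below; slice out token 0
+            // offline when diffing against the HF reference.)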
+ if (getenv("LTXV_DUMP_COND")) { + FILE* f = fopen("/tmp/ltxv_cond_concat.bin", "wb"); + if (f) { + fwrite(concat.data(), sizeof(float), concat.numel(), f); + fclose(f); + LOG_INFO("[ltxv.cond] dumped concat to /tmp/ltxv_cond_concat.bin (%zu floats)", + (size_t)concat.numel()); + } + } + } + // 188160 → caption_channels (4096). + auto projected = video_proj->compute(n_threads, concat); + if (projected.empty()) { + LOG_WARN("text_embedding_projection failed — falling back to zero embeddings"); + cond.c_crossattn = sd::zeros({caption_channels, max_tokens, 1}); + return cond; + } + { + double mn = 1e30, mx = -1e30, sum = 0, sq = 0; + for (int64_t i = 0; i < projected.numel(); ++i) { + double v = projected.data()[i]; + if (v < mn) mn = v; + if (v > mx) mx = v; + sum += v; sq += v * v; + } + double mean = sum / projected.numel(); + double std = std::sqrt(std::max(0.0, sq / projected.numel() - mean * mean)); + LOG_INFO("[ltxv.cond] projected: shape=[%zu,%zu] min=%.3f max=%.3f mean=%.3f std=%.3f", + (size_t)projected.shape()[0], (size_t)projected.shape()[1], mn, mx, mean, std); + if (getenv("LTXV_DUMP_COND")) { + FILE* f = fopen("/tmp/ltxv_cond_projected.bin", "wb"); + if (f) { + fwrite(projected.data(), sizeof(float), projected.numel(), f); + fclose(f); + LOG_INFO("[ltxv.cond] dumped projected to /tmp/ltxv_cond_projected.bin"); + } + } + } + cond.c_crossattn = std::move(projected); + return cond; + } +}; + // ldm.modules.encoders.modules.FrozenCLIPEmbedder // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { diff --git a/src/diffusion_model.hpp b/src/diffusion_model.hpp index c0a2a11c0..d97de9d8c 100644 --- a/src/diffusion_model.hpp +++ b/src/diffusion_model.hpp @@ -5,6 +5,7 @@ #include "anima.hpp" #include "ernie_image.hpp" #include "flux.hpp" +#include "ltxv.hpp" #include "mmdit.hpp" #include "qwen_image.hpp" #include "tensor_ggml.hpp" @@ -517,6 +518,49 @@ struct ZImageModel : public DiffusionModel { } }; +struct LTXV2Model : public DiffusionModel { + std::string prefix; + LTXV::LTXVRunner ltxv; + + LTXV2Model(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "model.diffusion_model", + SDVersion version = VERSION_LTXV2) + : prefix(prefix), ltxv(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) { + } + + std::string get_desc() override { return ltxv.get_desc(); } + void alloc_params_buffer() override { ltxv.alloc_params_buffer(); } + void free_params_buffer() override { ltxv.free_params_buffer(); } + void free_compute_buffer() override { ltxv.free_compute_buffer(); } + void get_param_tensors(std::map& tensors) override { + ltxv.get_param_tensors(tensors, prefix); + } + size_t get_params_buffer_size() override { return ltxv.get_params_buffer_size(); } + void set_weight_adapter(const std::shared_ptr& adapter) override { + ltxv.set_weight_adapter(adapter); + } + int64_t get_adm_in_channels() override { return 2048; } + void set_flash_attention_enabled(bool enabled) override { + ltxv.set_flash_attention_enabled(enabled); + } + void set_circular_axes(bool circular_x, bool circular_y) override { + ltxv.set_circular_axes(circular_x, circular_y); + } + + sd::Tensor compute(int n_threads, + const DiffusionParams& diffusion_params) override { + GGML_ASSERT(diffusion_params.x != nullptr); + GGML_ASSERT(diffusion_params.timesteps != nullptr); + return 
ltxv.compute(n_threads, + *diffusion_params.x, + *diffusion_params.timesteps, + tensor_or_empty(diffusion_params.context), + sd::Tensor()); // encoder attention mask (TODO: plumb through) + } +}; + struct ErnieImageModel : public DiffusionModel { std::string prefix; ErnieImage::ErnieImageRunner ernie_image; diff --git a/src/gemma3.hpp b/src/gemma3.hpp new file mode 100644 index 000000000..ceb3dc891 --- /dev/null +++ b/src/gemma3.hpp @@ -0,0 +1,692 @@ +// Gemma-3 text encoder for LTX-Video 2.3 conditioning. +// +// Architecture reference: llama.cpp src/models/gemma3.cpp (LLM_ARCH_GEMMA3) +// and HuggingFace transformers modeling_gemma3.py. +// +// Only the *text* sub-model is implemented — LTX-2.3 feeds the prompt +// through Gemma-3-12B-it's 48 transformer layers and concatenates the 49 +// resulting hidden states (input embedding + 48 layer outputs) along the +// last dim, then runs them through a per-modality linear (baked into the +// LTX-2.3 safetensors under `text_embedding_projection.*`) and through +// `video_embeddings_connector` to produce the cross-attention keys used +// by every block of the LTX video DiT. +// +// This file covers the GGML architecture + forward pass. Tokenizer and +// weight loading live in gemma3_tokenizer.{h,cpp} and gemma3_loader.{h,cpp}. +// +// Gemma-3-12B hyperparameters (from the model's config.json): +// hidden_size = 3840 intermediate_size = 15360 +// num_attention_heads = 16 num_key_value_heads = 8 (GQA, 2:1 ratio) +// head_dim = 256 num_hidden_layers = 48 +// rope_theta (global) = 1e6 rope_local_base_freq = 1e4 +// rope_scaling = linear factor 8 sliding_window = 1024 +// sliding_window_pattern = 6 (every 6th layer is full-attention) +// rms_norm_eps = 1e-6 +// query_pre_attn_scalar = 256 (attn_scale = 1 / sqrt(256) = 0.0625) +// hidden_activation = gelu_pytorch_tanh +// vocab_size = 262144 (tokens) + 64 special = 262208 + +#ifndef __GEMMA3_HPP__ +#define __GEMMA3_HPP__ + +#include +#include +#include +#include + +#include "common_block.hpp" +#include "ggml_extend.hpp" + +namespace GEMMA3 { + + constexpr int GEMMA3_GRAPH_SIZE = 32768; + + struct Gemma3Params { + int64_t hidden_size = 3840; + int64_t intermediate_size = 15360; + int64_t num_heads = 16; + int64_t num_kv_heads = 8; + int64_t head_dim = 256; + int64_t num_layers = 48; + int64_t vocab_size = 262208; + float rms_norm_eps = 1e-6f; + float rope_theta_global = 1e6f; + float rope_theta_local = 1e4f; + float rope_scaling_factor = 8.0f; // applied to GLOBAL rope only + int sliding_window = 1024; + int sliding_window_pattern = 6; // global attn every Nth layer + float query_pre_attn_scalar = 256.0f; // attn_scale = 1/sqrt(q_pre) + float embed_scale_sqrt_embd = 1.0f; // filled in ctor (sqrt(hidden_size)) + }; + + // Gemma-3 RMSNorm: applies `(1 + weight)` rather than `weight`, so the + // checkpoint stores weights initialised at 0. 
Equivalent to + // out = x * rsqrt(mean(x^2) + eps) * (1 + w) + class Gemma3RMSNorm : public UnaryBlock { + protected: + int64_t hidden_size; + float eps; + std::string prefix; + + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override { + this->prefix = prefix; + params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + } + + public: + Gemma3RMSNorm(int64_t hidden_size, float eps = 1e-6f) + : hidden_size(hidden_size), eps(eps) {} + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + ggml_tensor* w = params["weight"]; + if (ctx->weight_adapter) { + w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight"); + } + x = ggml_rms_norm(ctx->ggml_ctx, x, eps); + // Equivalent to `x * (1 + w)` — add a fresh f32 "1" tensor of + // matching shape, or use ggml_add with a constant. ggml_scale + // on `x` would need two ops; cleanest is to materialise + // `(1 + w)` at graph-build time, but `w` lives on the backend. + // So we do it with two ggml ops: tmp = x * w + x == x * (1+w). + auto mul = ggml_mul(ctx->ggml_ctx, x, w); + return ggml_add(ctx->ggml_ctx, mul, x); + } + }; + + // Gemma-3 MLP: SwiGLU variant using GELU (pytorch_tanh approximation). + // out = down(gelu_tanh(gate(x)) * up(x)) + class Gemma3MLP : public GGMLBlock { + public: + Gemma3MLP(int64_t hidden_size, int64_t intermediate_size) { + blocks["gate_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, /*bias=*/false)); + blocks["up_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, /*bias=*/false)); + blocks["down_proj"] = std::shared_ptr(new Linear(intermediate_size, hidden_size, /*bias=*/false)); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto gate = std::dynamic_pointer_cast(blocks["gate_proj"]); + auto up = std::dynamic_pointer_cast(blocks["up_proj"]); + auto down = std::dynamic_pointer_cast(blocks["down_proj"]); + + auto g = gate->forward(ctx, x); + g = ggml_gelu_inplace(ctx->ggml_ctx, g); // tanh-approx + auto u = up->forward(ctx, x); + auto h = ggml_mul_inplace(ctx->ggml_ctx, g, u); + return down->forward(ctx, h); + } + }; + + // Single Gemma-3 decoder block. 
+ // attn_branch : pre_attn_norm -> (Q,K,V) -> q_norm/k_norm -> RoPE + // -> GQA (sliding-window or global) -> post_attn_norm + // -> residual + // ffn_branch : pre_ffn_norm -> Gemma3MLP -> post_ffn_norm -> residual + class Gemma3Block : public GGMLBlock { + protected: + Gemma3Params params_; + int layer_idx; + + public: + Gemma3Block(const Gemma3Params& p, int layer_idx) : params_(p), layer_idx(layer_idx) { + int64_t q_dim = p.num_heads * p.head_dim; + int64_t kv_dim = p.num_kv_heads * p.head_dim; + + blocks["input_layernorm"] = std::shared_ptr(new Gemma3RMSNorm(p.hidden_size, p.rms_norm_eps)); + blocks["post_attention_layernorm"] = std::shared_ptr(new Gemma3RMSNorm(p.hidden_size, p.rms_norm_eps)); + blocks["pre_feedforward_layernorm"] = std::shared_ptr(new Gemma3RMSNorm(p.hidden_size, p.rms_norm_eps)); + blocks["post_feedforward_layernorm"] = std::shared_ptr(new Gemma3RMSNorm(p.hidden_size, p.rms_norm_eps)); + + blocks["self_attn.q_proj"] = std::shared_ptr(new Linear(p.hidden_size, q_dim, /*bias=*/false)); + blocks["self_attn.k_proj"] = std::shared_ptr(new Linear(p.hidden_size, kv_dim, /*bias=*/false)); + blocks["self_attn.v_proj"] = std::shared_ptr(new Linear(p.hidden_size, kv_dim, /*bias=*/false)); + blocks["self_attn.o_proj"] = std::shared_ptr(new Linear(q_dim, p.hidden_size, /*bias=*/false)); + blocks["self_attn.q_norm"] = std::shared_ptr(new Gemma3RMSNorm(p.head_dim, p.rms_norm_eps)); + blocks["self_attn.k_norm"] = std::shared_ptr(new Gemma3RMSNorm(p.head_dim, p.rms_norm_eps)); + + blocks["mlp"] = std::shared_ptr(new Gemma3MLP(p.hidden_size, p.intermediate_size)); + } + + // Returns (layer_output, residual_after_attn) — the latter is useful + // for the final hidden-state list. We concatenate per-layer outputs + // outside this class. + // + // rope_cos/rope_sin: precomputed per-token cos/sin tables. The caller + // picks the right one (local for sliding layers, global for full). + // attn_mask: [L, L] additive mask; caller builds the sliding-window + // band or leaves nullptr for full-attention layers. + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* rope_cos, + ggml_tensor* rope_sin, + ggml_tensor* attn_mask /* may be nullptr */) { + auto in_norm = std::dynamic_pointer_cast(blocks["input_layernorm"]); + auto post_attn = std::dynamic_pointer_cast(blocks["post_attention_layernorm"]); + auto pre_ffn = std::dynamic_pointer_cast(blocks["pre_feedforward_layernorm"]); + auto post_ffn = std::dynamic_pointer_cast(blocks["post_feedforward_layernorm"]); + + auto q_proj = std::dynamic_pointer_cast(blocks["self_attn.q_proj"]); + auto k_proj = std::dynamic_pointer_cast(blocks["self_attn.k_proj"]); + auto v_proj = std::dynamic_pointer_cast(blocks["self_attn.v_proj"]); + auto o_proj = std::dynamic_pointer_cast(blocks["self_attn.o_proj"]); + auto q_norm = std::dynamic_pointer_cast(blocks["self_attn.q_norm"]); + auto k_norm = std::dynamic_pointer_cast(blocks["self_attn.k_norm"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + + auto residual = x; + + // --- attention branch --- + auto h = in_norm->forward(ctx, x); + auto q = q_proj->forward(ctx, h); // [q_dim, L, N] + auto k = k_proj->forward(ctx, h); // [kv_dim, L, N] + auto v = v_proj->forward(ctx, h); // [kv_dim, L, N] + + int64_t L = q->ne[1]; + int64_t N = q->ne[2]; + + // q_norm / k_norm are PER-HEAD — reshape to expose head_dim on + // the inner axis, apply RMSNorm, reshape back. 
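+            // The reshapes below give q = [head_dim, num_heads, L, N] and
+            // k/v = [head_dim, num_kv_heads, L, N] (ggml ne order, head_dim innermost).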
+ q = ggml_reshape_4d(ctx->ggml_ctx, q, params_.head_dim, params_.num_heads, L, N); + k = ggml_reshape_4d(ctx->ggml_ctx, k, params_.head_dim, params_.num_kv_heads, L, N); + v = ggml_reshape_4d(ctx->ggml_ctx, v, params_.head_dim, params_.num_kv_heads, L, N); + + q = q_norm->forward(ctx, q); + k = k_norm->forward(ctx, k); + + // Apply RoPE to Q and K. Q has num_heads heads; K has num_kv_heads. + // RoPE tables are shape [head_dim/2 or head_dim, L] depending on + // variant — here we use interleaved (standard Gemma-3). + q = apply_rotary_emb(ctx, q, rope_cos, rope_sin); + k = apply_rotary_emb(ctx, k, rope_cos, rope_sin); + + // Gemma-3 uses `scale = 1/sqrt(query_pre_attn_scalar)` — for + // Gemma-3-12B this equals `1/sqrt(head_dim)` (both are 256), + // which ggml_ext_attention_ext applies internally. If this + // assumption ever breaks (e.g. 27B), apply the corrective + // factor (sqrt(head_dim) / sqrt(query_pre_attn_scalar)) here. + GGML_ASSERT(params_.query_pre_attn_scalar == params_.head_dim); + + // GQA: K and V each map to num_heads by repeat (num_heads / + // num_kv_heads copies). ggml's attention helper handles this + // when we pass K/V with num_kv_heads directly if the backend + // supports broadcasting; otherwise we tile. + auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, + ggml_reshape_3d(ctx->ggml_ctx, q, + params_.head_dim * params_.num_heads, L, N), + ggml_reshape_3d(ctx->ggml_ctx, k, + params_.head_dim * params_.num_kv_heads, L, N), + ggml_reshape_3d(ctx->ggml_ctx, v, + params_.head_dim * params_.num_kv_heads, L, N), + params_.num_heads, + attn_mask, + /*scale_for_sdp=*/false, + ctx->flash_attn_enabled); + auto attn = o_proj->forward(ctx, attn_out); + attn = post_attn->forward(ctx, attn); + x = ggml_add(ctx->ggml_ctx, residual, attn); + + // --- FFN branch --- + residual = x; + auto ff = pre_ffn->forward(ctx, x); + ff = mlp->forward(ctx, ff); + ff = post_ffn->forward(ctx, ff); + return ggml_add(ctx->ggml_ctx, residual, ff); + } + + private: + // NEOX-style RoPE (matches Gemma-3 in llama.cpp): pair is + // (x[k], x[k + D/2]) for k in [0, D/2) + // rotation: + // x_new[k] = x[k] * cos[k] - x[k + D/2] * sin[k] + // x_new[k + D/2] = x[k + D/2] * cos[k] + x[k] * sin[k] + // cos/sin are [D, L] (duplicated: cos[k] == cos[k+D/2], same for sin) + // so we can apply via element-wise multiplies. + static ggml_tensor* apply_rotary_emb(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* cos, + ggml_tensor* sin) { + // x: [head_dim, n_heads, L, N] + int64_t D = x->ne[0]; + int64_t H = x->ne[1]; + int64_t L = x->ne[2]; + int64_t N = x->ne[3]; + int64_t r = D / 2; + + // Split x along the head_dim axis into first half (k=0..r-1) + // and second half (k=r..D-1), both shape [r, H, L, N]. + // In ggml ne order, ne[0] is innermost; use views into the + // contiguous memory. + auto first = ggml_view_4d(ctx->ggml_ctx, x, r, H, L, N, + x->nb[1], x->nb[2], x->nb[3], 0); + auto second = ggml_view_4d(ctx->ggml_ctx, x, r, H, L, N, + x->nb[1], x->nb[2], x->nb[3], + x->nb[0] * r); + first = ggml_cont(ctx->ggml_ctx, first); + second = ggml_cont(ctx->ggml_ctx, second); + + // cos / sin broadcast over (H, N). 
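+            // cos/sin arrive as [D, L]; reshaping to [D, 1, L, 1] lets ggml_mul
+            // broadcast them across the head and batch axes.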
+ auto cos_b = ggml_reshape_4d(ctx->ggml_ctx, cos, D, 1, L, 1); + auto sin_b = ggml_reshape_4d(ctx->ggml_ctx, sin, D, 1, L, 1); + auto cos_first = ggml_view_4d(ctx->ggml_ctx, cos_b, r, 1, L, 1, + cos_b->nb[1], cos_b->nb[2], cos_b->nb[3], 0); + auto sin_first = ggml_view_4d(ctx->ggml_ctx, sin_b, r, 1, L, 1, + sin_b->nb[1], sin_b->nb[2], sin_b->nb[3], 0); + cos_first = ggml_cont(ctx->ggml_ctx, cos_first); + sin_first = ggml_cont(ctx->ggml_ctx, sin_first); + + // first_new = first * cos - second * sin + // second_new = second * cos + first * sin + auto first_new = ggml_sub(ctx->ggml_ctx, + ggml_mul(ctx->ggml_ctx, first, cos_first), + ggml_mul(ctx->ggml_ctx, second, sin_first)); + auto second_new = ggml_add(ctx->ggml_ctx, + ggml_mul(ctx->ggml_ctx, second, cos_first), + ggml_mul(ctx->ggml_ctx, first, sin_first)); + + // Concatenate back along head_dim. + return ggml_concat(ctx->ggml_ctx, first_new, second_new, 0); + } + }; + + // Full Gemma-3 text model: embedding + 48 decoder blocks + final RMSNorm. + // Exposes `forward_with_hidden_states` that returns all 49 intermediate + // hidden states (post-embedding + each of 48 layer outputs) so the LTX + // embeddings processor can concatenate them. + class Gemma3TextModel : public GGMLBlock { + public: + Gemma3Params params_; + + Gemma3TextModel(const Gemma3Params& p) : params_(p) { + blocks["embed_tokens"] = std::shared_ptr(new Embedding(p.vocab_size, p.hidden_size)); + for (int64_t i = 0; i < p.num_layers; ++i) { + blocks["layers." + std::to_string(i)] = + std::shared_ptr(new Gemma3Block(p, (int)i)); + } + blocks["norm"] = std::shared_ptr(new Gemma3RMSNorm(p.hidden_size, p.rms_norm_eps)); + } + + // input_ids: [L, N=1] int32 + // rope_cos_global / rope_sin_global: [head_dim, L] (global θ+scaling) + // rope_cos_local / rope_sin_local: [head_dim, L] (local θ, no scaling) + // sliding_mask: [L, L] additive mask with the 1024-band; full layers + // use nullptr. + // hidden_out: caller-provided vector to receive intermediate tensors. + // After a full forward it will have num_layers+1 + // entries: [post-embed, layer0_out, ..., layer_{N-1}_out]. + // The last entry is REPLACED with the post-final-norm + // result on return. + // + // max_layers: run at most this many decoder blocks; -1 = all. + // When < num_layers, the final norm is NOT applied and + // hidden_out contains [post-embed, layer0_out, ..., + // layer_{max_layers-1}_out]. + ggml_tensor* forward_with_hidden_states(GGMLRunnerContext* ctx, + ggml_tensor* input_ids, + ggml_tensor* rope_cos_global, + ggml_tensor* rope_sin_global, + ggml_tensor* rope_cos_local, + ggml_tensor* rope_sin_local, + ggml_tensor* sliding_mask, + ggml_tensor* full_mask, + std::vector& hidden_out, + int64_t max_layers = -1) { + auto embed = std::dynamic_pointer_cast(blocks["embed_tokens"]); + auto fnorm = std::dynamic_pointer_cast(blocks["norm"]); + + auto x = embed->forward(ctx, input_ids); + // Gemma paper: embeddings scaled by sqrt(hidden_size). + x = ggml_scale(ctx->ggml_ctx, x, std::sqrt((float)params_.hidden_size)); + + int64_t lim = max_layers < 0 ? params_.num_layers + : std::min(max_layers, params_.num_layers); + hidden_out.clear(); + hidden_out.reserve(lim + 1); + hidden_out.push_back(x); + + for (int64_t i = 0; i < lim; ++i) { + auto blk = std::dynamic_pointer_cast( + blocks["layers." + std::to_string(i)]); + bool is_global = ((i + 1) % params_.sliding_window_pattern) == 0; + auto* cos = is_global ? rope_cos_global : rope_cos_local; + auto* sin = is_global ? 
rope_sin_global : rope_sin_local; + // Gemma-3 uses CAUSAL attention everywhere. Full-attention + // layers get a plain causal mask; sliding layers get a + // windowed causal mask. Caller provides both under the + // `full_mask` / `sliding_mask` names. + auto* msk = is_global ? full_mask : sliding_mask; + x = blk->forward(ctx, x, cos, sin, msk); + hidden_out.push_back(x); + } + if (lim == params_.num_layers) { + x = fnorm->forward(ctx, x); + // Replace last entry with post-final-norm. + hidden_out.back() = x; + } + return x; + } + }; + + // Precompute interleaved RoPE tables on CPU. The LTX pipeline encodes a + // single short prompt (max ~256 tokens); we materialise the full + // [head_dim, L] cos/sin once per run. + struct RopeTables { + std::vector cos; + std::vector sin; + int64_t L = 0; + int64_t dim = 0; + }; + + __STATIC_INLINE__ RopeTables compute_gemma3_rope(int64_t L, + int64_t head_dim, + float theta, + float scaling_factor) { + RopeTables t; + t.L = L; + t.dim = head_dim; + t.cos.assign(L * head_dim, 0.f); + t.sin.assign(L * head_dim, 0.f); + // NEOX RoPE layout: pairs are (x[k], x[k+D/2]). + // cos[pos*D + k] = cos[pos*D + k + D/2] = cos(scaled_pos * freq_k) + // i.e. the first half of the dim holds the values and the second + // half is a duplicate — so `apply_rotary_emb` can just broadcast. + // freq_k = 1 / theta^(2k / head_dim) for k in [0, D/2). + int64_t half = head_dim / 2; + for (int64_t pos = 0; pos < L; ++pos) { + float scaled_pos = (float)pos / scaling_factor; + for (int64_t k = 0; k < half; ++k) { + float freq = 1.0f / std::pow(theta, (float)(2 * k) / (float)head_dim); + float ang = scaled_pos * freq; + float c = std::cos(ang); + float s = std::sin(ang); + t.cos[pos * head_dim + k] = c; + t.cos[pos * head_dim + k + half] = c; + t.sin[pos * head_dim + k] = s; + t.sin[pos * head_dim + k + half] = s; + } + } + return t; + } + + // Build an additive causal sliding-window mask of shape [L, L]: + // mask[i, j] = 0 if j <= i && i - j < window + // = -inf otherwise + // Gemma-3 uses causal attention for both sliding and full layers + // (`use_bidirectional_attention = False` in the text_config). For full- + // attention layers, pass `window = L` to get a plain causal mask. + __STATIC_INLINE__ std::vector build_causal_mask(int64_t L, int window) { + std::vector m(L * L, -INFINITY); + for (int64_t i = 0; i < L; ++i) { + int64_t lo = std::max(0, i - window + 1); + for (int64_t j = lo; j <= i; ++j) { + m[i * L + j] = 0.0f; + } + } + return m; + } + + // Back-compat shim. + __STATIC_INLINE__ std::vector build_sliding_mask(int64_t L, int window) { + return build_causal_mask(L, window); + } + + // GGMLRunner wrapper: allocates params_buffer, builds graph per call. + // Owns two sets of precomputed RoPE tables (local + global) and the + // sliding mask, uploaded to the backend per compute() invocation. 
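+    // Minimal usage sketch (as exercised by examples/gemma_test/gemma3_test.cpp):
+    //   GEMMA3::Gemma3Runner runner(backend, /*offload=*/false, loader.get_tensor_storage_map(), "model");
+    //   runner.alloc_params_buffer();  // then runner.get_param_tensors(...) + loader.load_tensors(...)
+    //   auto feats = runner.compute_concatenated_hiddens(n_threads, input_ids);  // [188160, L, 1]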
+ struct Gemma3Runner : public GGMLRunner { + Gemma3Params params; + Gemma3TextModel model; + RopeTables rope_global; + RopeTables rope_local; + std::vector sliding_mask; + std::vector full_mask; + + Gemma3Runner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string prefix = "model") + : GGMLRunner(backend, offload_params_to_cpu), model(params) { + model.init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { return "gemma3_12b"; } + + void get_param_tensors(std::map& tensors, + const std::string prefix) { + model.get_param_tensors(tensors, prefix); + } + + // Build graph, set RoPE + mask tensors, run, return the last layer's + // hidden state (shape [hidden, L, 1]) as sd::Tensor. For LTX we will + // also need the INTERMEDIATE hidden states — see compute_all_layers. + ggml_cgraph* build_graph(const sd::Tensor& input_ids, + const std::vector& hidden_out_slots, + bool want_final = true) { + auto gf = ggml_new_graph_custom(compute_ctx, GEMMA3_GRAPH_SIZE, false); + auto ids_t = make_input(input_ids); + int64_t L = ids_t->ne[0]; + + // Lazily rebuild rope / mask to match L. + if (rope_global.L != L) { + rope_global = compute_gemma3_rope(L, params.head_dim, params.rope_theta_global, params.rope_scaling_factor); + rope_local = compute_gemma3_rope(L, params.head_dim, params.rope_theta_local, /*scaling=*/1.0f); + sliding_mask = build_causal_mask(L, params.sliding_window); + full_mask = build_causal_mask(L, (int)L); + } + + auto rope_cos_g = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, params.head_dim, L); + auto rope_sin_g = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, params.head_dim, L); + auto rope_cos_l = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, params.head_dim, L); + auto rope_sin_l = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, params.head_dim, L); + auto mask_s = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, L, L); + auto mask_f = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, L, L); + set_backend_tensor_data(rope_cos_g, rope_global.cos.data()); + set_backend_tensor_data(rope_sin_g, rope_global.sin.data()); + set_backend_tensor_data(rope_cos_l, rope_local.cos.data()); + set_backend_tensor_data(rope_sin_l, rope_local.sin.data()); + set_backend_tensor_data(mask_s, sliding_mask.data()); + set_backend_tensor_data(mask_f, full_mask.data()); + + auto rctx = get_context(); + std::vector hidden_all; + auto out = model.forward_with_hidden_states(&rctx, ids_t, + rope_cos_g, rope_sin_g, + rope_cos_l, rope_sin_l, + mask_s, mask_f, hidden_all); + // Publish the hidden states the caller asked for. + GGML_ASSERT(hidden_out_slots.size() <= hidden_all.size()); + for (size_t i = 0; i < hidden_out_slots.size(); ++i) { + if (hidden_out_slots[i]) *hidden_out_slots[i] = hidden_all[i]; + } + // Expand all requested hidden states first so graph scheduling + // keeps them reachable, then `out` last so it remains final. + for (auto* h : hidden_all) ggml_build_forward_expand(gf, h); + if (want_final) ggml_build_forward_expand(gf, out); + return gf; + } + + // Compute and return ONE specific hidden state by index. + // layer_idx=0 → post-embed; 1..num_layers-1 → post-block i-1; + // num_layers → post-final-norm (full model output). + // + // Implementation note: we inline the forward pass here and STOP + // when we reach the target layer, so the graph's last node is + // exactly the tensor we want. 
This bypasses the gallocr buffer- + // reuse surprise that makes hidden_out entries unreadable after + // later layers overwrite them. + sd::Tensor compute_layer_hidden(int n_threads, + const sd::Tensor& input_ids, + int layer_idx) { + auto get_graph = [&]() -> ggml_cgraph* { + auto* gf = ggml_new_graph_custom(compute_ctx, GEMMA3_GRAPH_SIZE, false); + auto ids_t = make_input(input_ids); + int64_t L = ids_t->ne[0]; + if (rope_global.L != L) { + rope_global = compute_gemma3_rope(L, params.head_dim, + params.rope_theta_global, + params.rope_scaling_factor); + rope_local = compute_gemma3_rope(L, params.head_dim, + params.rope_theta_local, 1.0f); + sliding_mask = build_causal_mask(L, params.sliding_window); + full_mask = build_causal_mask(L, (int)L); + } + auto rctx = get_context(); + // Conditionally create the input tensors we'll actually + // use. `set_backend_tensor_data` is only called for tensors + // we DEFINITELY put in the graph — otherwise compute<> + // tries to upload data to unallocated tensors and asserts. + // + // For layer_idx=N (num_layers or -1), all layers run, so we + // need both global and local RoPE + mask. For a truncated + // forward we compute which RoPE families are required. + int64_t max_layers = (layer_idx < 0) ? params.num_layers + : (int64_t)layer_idx; + bool need_global = false; + bool need_local = false; + for (int64_t i = 0; i < max_layers; ++i) { + bool is_global = ((i + 1) % params.sliding_window_pattern) == 0; + if (is_global) need_global = true; + else need_local = true; + } + + ggml_tensor* rope_cos_g = nullptr; + ggml_tensor* rope_sin_g = nullptr; + ggml_tensor* rope_cos_l = nullptr; + ggml_tensor* rope_sin_l = nullptr; + ggml_tensor* mask_s = nullptr; + ggml_tensor* mask_f = nullptr; + if (need_global) { + rope_cos_g = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, params.head_dim, L); + rope_sin_g = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, params.head_dim, L); + mask_f = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, L, L); + set_backend_tensor_data(rope_cos_g, rope_global.cos.data()); + set_backend_tensor_data(rope_sin_g, rope_global.sin.data()); + set_backend_tensor_data(mask_f, full_mask.data()); + } + if (need_local) { + rope_cos_l = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, params.head_dim, L); + rope_sin_l = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, params.head_dim, L); + mask_s = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, L, L); + set_backend_tensor_data(rope_cos_l, rope_local.cos.data()); + set_backend_tensor_data(rope_sin_l, rope_local.sin.data()); + set_backend_tensor_data(mask_s, sliding_mask.data()); + } + + std::vector hidden_all; + model.forward_with_hidden_states(&rctx, ids_t, + rope_cos_g, rope_sin_g, + rope_cos_l, rope_sin_l, + mask_s, mask_f, + hidden_all, max_layers); + ggml_tensor* pick = hidden_all.back(); + auto pick_out = ggml_cont(compute_ctx, pick); + ggml_build_forward_expand(gf, pick_out); + return gf; + }; + auto result = GGMLRunner::compute(get_graph, n_threads, false); + if (!result.has_value()) return {}; + return std::move(*result); + } + + // Compute all 49 hidden states, apply LTX-2.3's FeatureExtractorV2 + // normalisation (per-token per-layer RMS-norm along the hidden + // axis), concatenate along the channel axis, and rescale by + // sqrt(out/hidden). Returns [188160, L, 1]. + // + // This is the exact input that + // `text_embedding_projection.video_aggregate_embed` (a Linear from + // the LTX-2.3 22B safetensors) expects — projecting to 4096-dim + // cross-attention features for the video DiT. 
+ // + // Reference: ltx_core.text_encoders.gemma.feature_extractor + // FeatureExtractorV2.forward: + // encoded = stack(hidden_states, dim=-1) # [B, T, D, L] + // normed = norm_and_concat_per_token_rms(...) # [B, T, D*L] + // normed *= sqrt(out/D) + // return video_aggregate_embed(normed) # [B, T, out] + sd::Tensor compute_concatenated_hiddens(int n_threads, + const sd::Tensor& input_ids, + int64_t target_out_dim = 4096) { + auto get_graph = [&]() -> ggml_cgraph* { + auto* gf = ggml_new_graph_custom(compute_ctx, GEMMA3_GRAPH_SIZE, false); + auto ids_t = make_input(input_ids); + int64_t L = ids_t->ne[0]; + if (rope_global.L != L) { + rope_global = compute_gemma3_rope(L, params.head_dim, + params.rope_theta_global, + params.rope_scaling_factor); + rope_local = compute_gemma3_rope(L, params.head_dim, + params.rope_theta_local, 1.0f); + sliding_mask = build_causal_mask(L, params.sliding_window); + full_mask = build_causal_mask(L, (int)L); + } + auto rope_cos_g = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, params.head_dim, L); + auto rope_sin_g = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, params.head_dim, L); + auto rope_cos_l = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, params.head_dim, L); + auto rope_sin_l = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, params.head_dim, L); + auto mask_s = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, L, L); + auto mask_f = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, L, L); + set_backend_tensor_data(rope_cos_g, rope_global.cos.data()); + set_backend_tensor_data(rope_sin_g, rope_global.sin.data()); + set_backend_tensor_data(rope_cos_l, rope_local.cos.data()); + set_backend_tensor_data(rope_sin_l, rope_local.sin.data()); + set_backend_tensor_data(mask_s, sliding_mask.data()); + set_backend_tensor_data(mask_f, full_mask.data()); + + auto rctx = get_context(); + std::vector hidden_all; + model.forward_with_hidden_states(&rctx, ids_t, + rope_cos_g, rope_sin_g, + rope_cos_l, rope_sin_l, + mask_s, mask_f, hidden_all); + // FeatureExtractorV2: per-token RMSNorm along the hidden + // axis for EACH layer, then concat on the channel axis, + // then rescale by sqrt(out/hidden). + // + // IMPORTANT layout: the reference stacks hidden_states into + // [B, T, D, L] (with layer L as the LAST axis) and then + // reshape(B, T, D*L). That produces a flat axis whose fast + // index is L (layer) and slow index is D (hidden). The + // video_aggregate_embed Linear weight [4096, 188160] was + // trained with this exact ordering, so we must match it. + // + // ggml_concat along axis 0 on a list of [D, T, 1, 1] tensors + // yields [D*L, T] with the OPPOSITE order (D fast, L slow), + // so we instead stack into [D, T, L, 1], permute to + // [L, D, T, 1] and flatten to get (d slow, l fast). + GGML_ASSERT(hidden_all.size() > 0); + for (size_t i = 0; i < hidden_all.size(); ++i) { + // Per-token RMSNorm along ne[0]=D (innermost). + hidden_all[i] = ggml_rms_norm(compute_ctx, hidden_all[i], 1e-6f); + } + // Stack along axis 2 (layer axis): [D, T, L, 1]. + ggml_tensor* cat = hidden_all[0]; + if (hidden_all.size() > 1) { + for (size_t i = 1; i < hidden_all.size(); ++i) { + cat = ggml_concat(compute_ctx, cat, hidden_all[i], 2); + } + } + int64_t L_tok = cat->ne[1]; + int64_t L_lay = (int64_t)hidden_all.size(); + int64_t D = params.hidden_size; + // Permute to put L (layer) as ne[0] (fastest): [L, D, T, 1]. 
+ cat = ggml_cont(compute_ctx, + ggml_ext_torch_permute(compute_ctx, cat, 2, 0, 1, 3)); + // Flatten into [D*L, T, 1] — fast index is L, slow is D — + // matching HF's `reshape(B, T, D*L)` layout expected by + // text_embedding_projection.video_aggregate_embed. + cat = ggml_reshape_3d(compute_ctx, cat, D * L_lay, L_tok, 1); + // Rescale: multiply by sqrt(target_out_dim / hidden_size). + float scale = std::sqrt((float)target_out_dim / (float)params.hidden_size); + cat = ggml_scale(compute_ctx, cat, scale); + cat = ggml_cont(compute_ctx, cat); + ggml_build_forward_expand(gf, cat); + return gf; + }; + auto result = GGMLRunner::compute(get_graph, n_threads, false); + if (!result.has_value()) return {}; + return std::move(*result); + } + }; + +} // namespace GEMMA3 + +#endif // __GEMMA3_HPP__ diff --git a/src/ltxv.hpp b/src/ltxv.hpp index fb37dbe02..08028c659 100644 --- a/src/ltxv.hpp +++ b/src/ltxv.hpp @@ -1,73 +1,1833 @@ #ifndef __LTXV_HPP__ #define __LTXV_HPP__ +// LTX-Video 2.3 (Lightricks) port targeting +// Lightricks/LTX-2.3/ltx-2.3-22b-dev.safetensors (22B params, 5947 tensors) +// and its distilled siblings (8-step, CFG=1). +// +// The weight layout is inferred directly from the safetensors header of the +// official 22B checkpoint — diffusers' `transformer_ltx2.py` is a close but +// NOT identical reference (names and block counts differ for LTX-2.3). +// +// Scope: VIDEO-ONLY generation. +// * Every weight in the checkpoint (including audio self-attn, a2v/v2a +// cross-attn, audio FFN, audio VAE) is registered so loading succeeds. +// * The forward path exercises only the video branch — audio hidden state +// stays at zeros and the audio-to-video/video-to-audio paths are skipped +// (equivalent to diffusers `isolate_modalities=True` + discarding audio +// output). Enable them later for audio generation. +// +// Tensor-layout conventions: +// * torch (N, C, F, H, W) video is stored in ggml as ne = [W, H, F, C*N] +// * torch (N, L, D) tokens are stored as ne = [D, L, N, 1] + +#include +#include +#include +#include +#include +#include + #include "common_block.hpp" +#include "ggml_extend.hpp" +#include "model.h" +#include "rope.hpp" +#include "vae.hpp" namespace LTXV { + constexpr int LTXV_GRAPH_SIZE = 32768; + + // Debug probe registry: block forwards add intermediate tensors here; + // the Runner keeps them alive across compute and logs stats. + struct DebugProbes { + struct Entry { + std::string name; + ggml_tensor* tensor = nullptr; + }; + std::vector entries; + void add(const std::string& n, ggml_tensor* t) { + entries.push_back({n, t}); + } + void clear() { entries.clear(); } + }; + __STATIC_INLINE__ DebugProbes& debug_probes() { + static DebugProbes p; + return p; + } + + // 3-D depth-to-space (pixel-shuffle) matching einops + // rearrange(x, "b (c p1 p2 p3) f h w -> b c (f p1) (h p2) (w p3)") + // where the channel axis has structure (c outer, p1, p2, p3 inner). In + // ggml ne order the input is [W, H, F, C*p1*p2*p3] and the output is + // [W*p3, H*p2, F*p1, C]. Implemented as three separate passes that each + // peel one sub-axis off the channel, route it to its destination and + // merge it as the INNER sub-index — matching einops' conventions + // exactly (naive ggml_reshape_4d alone produces swapped sub-indices and + // causes the visible banding artefacts in decoded frames). 
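+    // Worked example (just restating the mapping above): with
+    // (p1, p2, p3) = (2, 2, 2) an input of ne = [W, H, F, 8*C] comes out
+    // as ne = [2W, 2H, 2F, C]; each group of 8 channel values is spread
+    // over the 2x2x2 neighbourhood of its source position, with p3 the
+    // inner sub-index of W, p2 of H and p1 of F.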
+ __STATIC_INLINE__ ggml_tensor* depth_to_space_3d(ggml_context* ctx, + ggml_tensor* x, + int p1, int p2, int p3) { + int64_t W = x->ne[0], H = x->ne[1], F = x->ne[2], Cb = x->ne[3]; + int64_t C = Cb / ((int64_t)p1 * p2 * p3); + GGML_ASSERT(C * p1 * p2 * p3 == Cb); + + // ---- pass p3: merge into W as inner sub-index ---------------- + if (p3 > 1) { + // Split p3 from channel into F*p3 (p3 outer within ne[2]). + x = ggml_reshape_4d(ctx, x, W, H, F * p3, C * p1 * p2); + // Isolate p3: ne=[W, H*F, p3, X]. + x = ggml_reshape_4d(ctx, x, W, H * F, p3, C * p1 * p2); + // Bring p3 innermost: [p3, W, H*F, X]. + x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); + // Merge p3 with W (p3 inner, w outer) and restore H, F. + x = ggml_reshape_4d(ctx, x, p3 * W, H, F, C * p1 * p2); + W *= p3; + } + + // ---- pass p2: merge into H as inner sub-index ---------------- + if (p2 > 1) { + x = ggml_reshape_4d(ctx, x, W, H, F * p2, C * p1); + // Isolate p2: ne=[W*H, F, p2, X]. + x = ggml_reshape_4d(ctx, x, W * H, F, p2, C * p1); + // Bring p2 next to W*H: [W*H, p2, F, X]. + x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); + // Split W*H → (W inner, H outer): ne=[W, H, p2, F*X]. + x = ggml_reshape_4d(ctx, x, W, H, p2, F * C * p1); + // Swap H ↔ p2 so that the next merge puts p2 inner of H. + x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); + // Merge p2 and H (p2 inner, h outer) and restore F, C*p1. + x = ggml_reshape_4d(ctx, x, W, p2 * H, F, C * p1); + H *= p2; + } + + // ---- pass p1: merge into F as inner sub-index ---------------- + if (p1 > 1) { + x = ggml_reshape_4d(ctx, x, W, H, F * p1, C); + // Split F*p1 into separate F and p1 axes: ne=[W*H, F, p1, C]. + x = ggml_reshape_4d(ctx, x, W * H, F, p1, C); + // Swap so p1 is inner of the merged F*p1: [W*H, p1, F, C]. + x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); + // Merge p1 with F (p1 inner, f outer) and restore W, H. + x = ggml_reshape_4d(ctx, x, W, H, p1 * F, C); + F *= p1; + } + + return x; + } + + // Patchify-convention 3-D depth-to-space used by the decoder's final + // unpatchify. Reference (ltx_core/model/video_vae/ops.py::unpatchify): + // rearrange(x, "b (c p r q) f h w -> b c (f p) (h q) (w r)", p=p_t, q=p_h, r=p_w) + // The channel axis is packed as (c outer, p_t, p_w middle, p_h inner). + // This is DIFFERENT from DepthToSpaceUpsample which uses (c, p_t, p_h, p_w) + // (p_w innermost). Using the wrong convention transposes every (p_h × p_w) + // output block and produces a visible fine-scale hatching artefact that + // survives every diffusion step. + __STATIC_INLINE__ ggml_tensor* depth_to_space_3d_patch(ggml_context* ctx, + ggml_tensor* x, + int p_t, int p_h, int p_w) { + int64_t W = x->ne[0], H = x->ne[1], F = x->ne[2], Cb = x->ne[3]; + int64_t C = Cb / ((int64_t)p_t * p_h * p_w); + GGML_ASSERT(C * p_t * p_h * p_w == Cb); + if (p_h == 1 && p_w == 1 && p_t == 1) { + return x; + } + if (p_h != 1 || p_w != 1) { + // Swap the inner (p_w, p_h) pair in the channel axis so the layout + // becomes (c, p_t, p_h, p_w) with p_w innermost — exactly what the + // DepthToSpaceUpsample convention (and therefore depth_to_space_3d) + // expects. Then the general helper can do the rest. + // Ne[3]=Cb is the slow axis; bring it to ne[0] to be able to split + // it into (p_h, p_w, C*p_t). + x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 3, 0, 1, 2)); + // Reinterpret channel axis as [p_h (inner), p_w, C*p_t]. 
Flat + // order in channel is (c * p_t * p_w * p_h + p_t * p_w * p_h + p_w * p_h + p_h); + // the innermost (fast) sub-index is p_h, next is p_w, outer is C*p_t. + x = ggml_reshape_4d(ctx, x, p_h, p_w, C * p_t, W * H * F); + // Swap the first two dims so p_w becomes fastest. + x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 0, 2, 3)); + // Re-merge and put channel back to ne[3]: [W, H, F, Cb]. + x = ggml_reshape_4d(ctx, x, Cb, W, H, F); + x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 3, 0)); + } + return depth_to_space_3d(ctx, x, p_t, p_h, p_w); + } + + // ================================================================= + // Shared primitives + // ================================================================= + + class RMSNormNoAffine : public UnaryBlock { + protected: + float eps; + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {} + + public: + RMSNormNoAffine(float eps = 1e-6f) : eps(eps) {} + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + return ggml_rms_norm(ctx->ggml_ctx, x, eps); + } + }; + + class PerChannelRMSNorm : public UnaryBlock { + protected: + float eps; + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {} + + public: + PerChannelRMSNorm(float eps = 1e-8f) : eps(eps) {} + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + auto h = ggml_ext_cont(ctx->ggml_ctx, + ggml_ext_torch_permute(ctx->ggml_ctx, x, 3, 0, 1, 2)); + h = ggml_rms_norm(ctx->ggml_ctx, h, eps); + h = ggml_ext_cont(ctx->ggml_ctx, + ggml_ext_torch_permute(ctx->ggml_ctx, h, 1, 2, 3, 0)); + return h; + } + }; + + // Temporal-causal 3-D conv with runtime causal flag. + // Weight layout follows diffusers' LTX2VideoCausalConv3d (the raw nn.Conv3d + // is wrapped in `self.conv`, so tensor names are `.conv.weight`). class CausalConv3d : public GGMLBlock { protected: - int time_kernel_size; + int64_t in_channels; + int64_t out_channels; + std::tuple kernel_size; // (kt, kh, kw) + std::tuple stride; + std::tuple dilation; + bool bias; + + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + // ggml_cuda_op_im2col_3d only supports F16/F32 destination tensors + // — BF16 weights (the native LTX-2.3 dtype) would trigger its + // GGML_ASSERT. Force F16 here so sd.cpp's loader converts BF16 + // from the checkpoint on its way in. F32 was tested and gave + // identical output scale, so F16 is safe. 
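+            // ggml tensors are limited to 4 dims, so the 5-D torch conv
+            // weight (out, in, kt, kh, kw) is stored with both channel axes
+            // folded into ne[3] = in_channels * out_channels; forward()
+            // passes in_channels to ggml_ext_conv_3d, presumably so the
+            // helper can recover that split.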
+ params["conv.weight"] = ggml_new_tensor_4d(ctx, + GGML_TYPE_F16, + std::get<2>(kernel_size), + std::get<1>(kernel_size), + std::get<0>(kernel_size), + in_channels * out_channels); + if (bias) { + params["conv.bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + } + } public: CausalConv3d(int64_t in_channels, int64_t out_channels, - int kernel_size = 3, - std::tuple stride = {1, 1, 1}, - int dilation = 1, - bool bias = true) { - time_kernel_size = kernel_size / 2; - blocks["conv"] = std::shared_ptr(new Conv3d(in_channels, - out_channels, - {kernel_size, kernel_size, kernel_size}, - stride, - {0, kernel_size / 2, kernel_size / 2}, - {dilation, 1, 1}, - bias)); + std::tuple kernel_size, + std::tuple stride = {1, 1, 1}, + std::tuple dilation = {1, 1, 1}, + bool bias = true) + : in_channels(in_channels), + out_channels(out_channels), + kernel_size(kernel_size), + stride(stride), + dilation(dilation), + bias(bias) {} + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, bool causal = true) { + ggml_tensor* w = params["conv.weight"]; + ggml_tensor* b = bias ? params["conv.bias"] : nullptr; + + int kt = std::get<0>(kernel_size); + int kh = std::get<1>(kernel_size); + int kw = std::get<2>(kernel_size); + + if (kt > 1) { + if (causal) { + auto first = ggml_view_4d(ctx->ggml_ctx, x, + x->ne[0], x->ne[1], 1, x->ne[3], + x->nb[1], x->nb[2], x->nb[3], 0); + auto pad_left = first; + for (int i = 1; i < kt - 1; ++i) { + pad_left = ggml_concat(ctx->ggml_ctx, pad_left, first, 2); + } + x = ggml_concat(ctx->ggml_ctx, pad_left, x, 2); + } else { + int half = (kt - 1) / 2; + if (half > 0) { + auto first = ggml_view_4d(ctx->ggml_ctx, x, + x->ne[0], x->ne[1], 1, x->ne[3], + x->nb[1], x->nb[2], x->nb[3], 0); + auto last = ggml_view_4d(ctx->ggml_ctx, x, + x->ne[0], x->ne[1], 1, x->ne[3], + x->nb[1], x->nb[2], x->nb[3], + x->nb[2] * (x->ne[2] - 1)); + auto pad_left = first; + for (int i = 1; i < half; ++i) { + pad_left = ggml_concat(ctx->ggml_ctx, pad_left, first, 2); + } + auto pad_right = last; + for (int i = 1; i < half; ++i) { + pad_right = ggml_concat(ctx->ggml_ctx, pad_right, last, 2); + } + x = ggml_concat(ctx->ggml_ctx, pad_left, x, 2); + x = ggml_concat(ctx->ggml_ctx, x, pad_right, 2); + } + } + } + + int lp_w = kw / 2, rp_w = kw / 2; + int lp_h = kh / 2, rp_h = kh / 2; + x = ggml_ext_pad_ext(ctx->ggml_ctx, x, lp_w, rp_w, lp_h, rp_h, 0, 0, 0, 0, + ctx->circular_x_enabled, ctx->circular_y_enabled); + + return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels, + std::get<2>(stride), std::get<1>(stride), std::get<0>(stride), + 0, 0, 0, + std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation)); + } + }; + + // ================================================================= + // Transformer primitives + // ================================================================= + + class TimestepEmbedderSingle : public GGMLBlock { + protected: + int64_t frequency_embedding_size; + + public: + TimestepEmbedderSingle(int64_t hidden_size, int64_t frequency_embedding_size = 256) + : frequency_embedding_size(frequency_embedding_size) { + blocks["linear_1"] = std::shared_ptr(new Linear(frequency_embedding_size, hidden_size, true)); + blocks["linear_2"] = std::shared_ptr(new Linear(hidden_size, hidden_size, true)); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* t) { + auto l1 = std::dynamic_pointer_cast(blocks["linear_1"]); + auto l2 = std::dynamic_pointer_cast(blocks["linear_2"]); + auto f = ggml_ext_timestep_embedding(ctx->ggml_ctx, t, frequency_embedding_size); 
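+            // Timestep MLP: sinusoidal embedding -> linear_1 -> SiLU -> linear_2.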
+ f = l1->forward(ctx, f); + f = ggml_silu_inplace(ctx->ggml_ctx, f); + f = l2->forward(ctx, f); + return f; + } + }; + + class AdaLayerNormSingle : public GGMLBlock { + public: + int64_t hidden_size; + int64_t num_mod_params; + + AdaLayerNormSingle(int64_t hidden_size, int64_t num_mod_params) + : hidden_size(hidden_size), num_mod_params(num_mod_params) { + blocks["emb.timestep_embedder"] = + std::shared_ptr(new TimestepEmbedderSingle(hidden_size)); + blocks["linear"] = std::shared_ptr(new Linear(hidden_size, num_mod_params * hidden_size, true)); + } + + std::pair forward(GGMLRunnerContext* ctx, ggml_tensor* t) { + auto emb = std::dynamic_pointer_cast(blocks["emb.timestep_embedder"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + auto embedded = emb->forward(ctx, t); + auto x = ggml_silu(ctx->ggml_ctx, embedded); + auto temb = linear->forward(ctx, x); + return {temb, embedded}; + } + }; + + class FeedForward : public GGMLBlock { + public: + FeedForward(int64_t dim, int64_t inner_dim = -1) { + if (inner_dim < 0) inner_dim = dim * 4; + blocks["net.0.proj"] = std::shared_ptr(new Linear(dim, inner_dim, true)); + blocks["net.2"] = std::shared_ptr(new Linear(inner_dim, dim, true)); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto fc1 = std::dynamic_pointer_cast(blocks["net.0.proj"]); + auto fc2 = std::dynamic_pointer_cast(blocks["net.2"]); + x = fc1->forward(ctx, x); + x = ggml_gelu_inplace(ctx->ggml_ctx, x); + x = fc2->forward(ctx, x); + return x; + } + }; + + // LTX-2.3 attention: gated, qk_norm_across_heads, split or interleaved RoPE. + // Parameters: to_q, to_k, to_v, to_out.0, q_norm, k_norm, to_gate_logits. + // rope_type selects between the two rotation layouts used by LTX-2.3: + // * "interleaved": pair indices (2k, 2k+1) rotate together + // * "split": pair indices (k, k+r) rotate together (r = D/2) + // LTX-2.3 22B uses `rope_type = "split"`. + class LTXAttention : public GGMLBlock { + public: + int64_t query_dim; + int64_t inner_dim; + int64_t kv_inner_dim; + int64_t num_heads; + int64_t head_dim; + bool has_rope; + std::string rope_type; + + LTXAttention(int64_t query_dim, + int64_t heads, + int64_t dim_head, + int64_t cross_attention_dim = -1, + bool attention_bias = true, + bool attention_out_bias = true, + bool apply_rope = true, + int64_t kv_heads = -1, + int64_t kv_dim_head = -1, + std::string rope_type = "split") + : query_dim(query_dim), + num_heads(heads), + head_dim(dim_head), + has_rope(apply_rope && cross_attention_dim < 0), + rope_type(rope_type) { + inner_dim = heads * dim_head; + if (kv_heads < 0) kv_heads = heads; + if (kv_dim_head < 0) kv_dim_head = dim_head; + kv_inner_dim = kv_heads * kv_dim_head; + int64_t kv_source_dim = (cross_attention_dim > 0) ? 
cross_attention_dim : query_dim; + + blocks["to_q"] = std::shared_ptr(new Linear(query_dim, inner_dim, attention_bias)); + blocks["to_k"] = std::shared_ptr(new Linear(kv_source_dim, kv_inner_dim, attention_bias)); + blocks["to_v"] = std::shared_ptr(new Linear(kv_source_dim, kv_inner_dim, attention_bias)); + blocks["to_out.0"] = std::shared_ptr(new Linear(inner_dim, query_dim, attention_out_bias)); + blocks["q_norm"] = std::shared_ptr(new RMSNorm(inner_dim, 1e-6f)); + blocks["k_norm"] = std::shared_ptr(new RMSNorm(kv_inner_dim, 1e-6f)); + blocks["to_gate_logits"] = std::shared_ptr(new Linear(query_dim, heads, true)); } ggml_tensor* forward(GGMLRunnerContext* ctx, - ggml_tensor* x, - bool causal = true) { - // x: [N*IC, ID, IH, IW] - // result: [N*OC, OD, OH, OW] - auto conv = std::dynamic_pointer_cast(blocks["conv"]); - if (causal) { - auto h = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2)); // [ID, N*IC, IH, IW] - auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0); // [N*IC, IH, IW] - first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]); // [N*IC, 1, IH, IW] - auto first_frame_pad = first_frame; - for (int i = 1; i < time_kernel_size - 1; i++) { - first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2); - } - x = ggml_concat(ctx, first_frame_pad, x, 2); + ggml_tensor* hidden_states, + ggml_tensor* encoder_hidden_states = nullptr, + ggml_tensor* query_rope_cos = nullptr, + ggml_tensor* query_rope_sin = nullptr, + ggml_tensor* key_rope_cos = nullptr, + ggml_tensor* key_rope_sin = nullptr, + ggml_tensor* attention_mask = nullptr, + const char* probe_prefix = nullptr) { + auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); + auto to_k = std::dynamic_pointer_cast(blocks["to_k"]); + auto to_v = std::dynamic_pointer_cast(blocks["to_v"]); + auto to_out = std::dynamic_pointer_cast(blocks["to_out.0"]); + auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); + auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); + auto gate = std::dynamic_pointer_cast(blocks["to_gate_logits"]); + + auto probe_attn = [&](const char* suffix, ggml_tensor* t) { + if (!probe_prefix) return; + std::string full = std::string(probe_prefix) + "_" + suffix; + auto dup = ggml_dup(ctx->ggml_ctx, t); + ggml_set_name(dup, full.c_str()); + debug_probes().add(full, dup); + }; + + ggml_tensor* kv_src = encoder_hidden_states != nullptr ? encoder_hidden_states : hidden_states; + probe_attn("kv_src", kv_src); + probe_attn("q_src", hidden_states); + + auto gate_logits = gate->forward(ctx, hidden_states); + probe_attn("gate_logits", gate_logits); + + auto q = to_q->forward(ctx, hidden_states); + auto k = to_k->forward(ctx, kv_src); + auto v = to_v->forward(ctx, kv_src); + probe_attn("q_proj", q); + probe_attn("k_proj", k); + probe_attn("v_proj", v); + + q = q_norm->forward(ctx, q); + k = k_norm->forward(ctx, k); + probe_attn("q_norm", q); + probe_attn("k_norm", k); + + if (has_rope && query_rope_cos != nullptr && query_rope_sin != nullptr) { + if (rope_type == "split") { + q = apply_split_rotary_emb(ctx, q, query_rope_cos, query_rope_sin, num_heads); + ggml_tensor* kc = key_rope_cos != nullptr ? key_rope_cos : query_rope_cos; + ggml_tensor* ks = key_rope_sin != nullptr ? key_rope_sin : query_rope_sin; + k = apply_split_rotary_emb(ctx, k, kc, ks, num_heads); + } else { + q = apply_rotary_emb(ctx, q, query_rope_cos, query_rope_sin); + ggml_tensor* kc = key_rope_cos != nullptr ? 
key_rope_cos : query_rope_cos; + ggml_tensor* ks = key_rope_sin != nullptr ? key_rope_sin : query_rope_sin; + k = apply_rotary_emb(ctx, k, kc, ks); + } + } + + auto out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, + num_heads, attention_mask, false, + ctx->flash_attn_enabled); + probe_attn("raw_attn_out", out); + + // Per-head gate: gates = 2 * sigmoid(gate_logits). Broadcast + // [heads, L_q, N] over head_dim via reshape to [1, heads, L_q, N]. + { + auto gates = ggml_sigmoid(ctx->ggml_ctx, gate_logits); + gates = ggml_scale(ctx->ggml_ctx, gates, 2.0f); + int64_t N = out->ne[2]; + int64_t L_q = out->ne[1]; + auto out_4d = ggml_reshape_4d(ctx->ggml_ctx, out, head_dim, num_heads, L_q, N); + auto gates_4d = ggml_reshape_4d(ctx->ggml_ctx, gates, 1, num_heads, L_q, N); + out_4d = ggml_mul(ctx->ggml_ctx, out_4d, gates_4d); + out = ggml_reshape_3d(ctx->ggml_ctx, out_4d, inner_dim, L_q, N); + } + probe_attn("after_gate", out); + + out = to_out->forward(ctx, out); + probe_attn("to_out", out); + return out; + } + + static ggml_tensor* apply_rotary_emb(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* cos_freqs, + ggml_tensor* sin_freqs) { + int64_t C = x->ne[0]; + int64_t L = x->ne[1]; + int64_t N = x->ne[2]; + + auto x4 = ggml_reshape_4d(ctx->ggml_ctx, x, 2, C / 2, L, N); + auto real = ggml_view_4d(ctx->ggml_ctx, x4, 1, C / 2, L, N, + x4->nb[1], x4->nb[2], x4->nb[3], 0); + auto imag = ggml_view_4d(ctx->ggml_ctx, x4, 1, C / 2, L, N, + x4->nb[1], x4->nb[2], x4->nb[3], x4->nb[0]); + auto real_c = ggml_cont(ctx->ggml_ctx, real); + auto imag_c = ggml_cont(ctx->ggml_ctx, imag); + auto neg_imag = ggml_neg(ctx->ggml_ctx, imag_c); + auto rotated = ggml_concat(ctx->ggml_ctx, neg_imag, real_c, 0); + rotated = ggml_reshape_4d(ctx->ggml_ctx, rotated, C, L, N, 1); + + auto x_cos = ggml_mul(ctx->ggml_ctx, x, cos_freqs); + auto x_sin = ggml_mul(ctx->ggml_ctx, rotated, sin_freqs); + return ggml_add(ctx->ggml_ctx, x_cos, x_sin); + } + + // Split-rope: pair is (x[k], x[k+r]) where r = D_per_head/2. + // In diffusers: x.reshape(..., 2, r), [first, second] = x.unbind(-2) + // first_new = first * cos - second * sin + // second_new = second * cos + first * sin + // reshape back. + // + // cos_freqs / sin_freqs are [inner_dim/2, L] tensors in our layout; + // we reshape them per head to [head_dim/2, L] via broadcast. + static ggml_tensor* apply_split_rotary_emb(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* cos_freqs, + ggml_tensor* sin_freqs, + int64_t num_heads) { + int64_t C = x->ne[0]; // inner_dim + int64_t L = x->ne[1]; + int64_t N = x->ne[2]; + int64_t D = C / num_heads; // head_dim + int64_t r = D / 2; + + // Reshape x from [C, L, N] to [r, 2, num_heads, L*N] so the last dim + // of the pair (the "2" axis) is at ggml axis 1. + auto x4 = ggml_reshape_4d(ctx->ggml_ctx, x, r, 2, num_heads, L * N); + // first = x4[:, 0, :, :] (ne = [r, 1, num_heads, L*N]) + // second = x4[:, 1, :, :] + auto first = ggml_view_4d(ctx->ggml_ctx, x4, r, 1, num_heads, L * N, + x4->nb[1], x4->nb[2], x4->nb[3], 0); + auto second = ggml_view_4d(ctx->ggml_ctx, x4, r, 1, num_heads, L * N, + x4->nb[1], x4->nb[2], x4->nb[3], x4->nb[1]); + first = ggml_cont(ctx->ggml_ctx, first); + second = ggml_cont(ctx->ggml_ctx, second); + + // cos/sin are [inner_dim/2, L] == [num_heads*r, L]. Reshape to + // [r, 1, num_heads, L] so they broadcast over the batch axis (L*N/L). 
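+            // Example with the 22B attention shape (head_dim 128, r = 64):
+            // element k of each head rotates against element k + 64; with
+            // cos = 1 and sin = 0 the formulas below reduce to the identity.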
+ auto cos_v = ggml_reshape_4d(ctx->ggml_ctx, cos_freqs, r, 1, num_heads, L); + auto sin_v = ggml_reshape_4d(ctx->ggml_ctx, sin_freqs, r, 1, num_heads, L); + + // first_new = first * cos - second * sin + // second_new = second * cos + first * sin + auto first_new = ggml_sub(ctx->ggml_ctx, + ggml_mul(ctx->ggml_ctx, first, cos_v), + ggml_mul(ctx->ggml_ctx, second, sin_v)); + auto second_new = ggml_add(ctx->ggml_ctx, + ggml_mul(ctx->ggml_ctx, second, cos_v), + ggml_mul(ctx->ggml_ctx, first, sin_v)); + + // Stack back along axis 1: [r, 2, num_heads, L*N] and reshape to [C, L, N]. + auto out = ggml_concat(ctx->ggml_ctx, first_new, second_new, 1); + out = ggml_reshape_3d(ctx->ggml_ctx, out, C, L, N); + return out; + } + }; + + // EmbeddingsConnector's internal transformer_1d_blocks: attn1 + ff with + // PRE-NORM (stateless rms_norm) before each op. The reference is + // Lightricks' Embeddings1DConnector._BasicTransformerBlock1D: it calls + // `rms_norm(h)` before attn1 and before ff; residuals add the un-normed + // input back. Without the pre-norms, residual magnitudes compound across + // the 8 blocks and drive the connector output to ~1e12. + class EmbeddingsConnectorBlock : public GGMLBlock { + public: + int64_t dim; + + EmbeddingsConnectorBlock(int64_t dim, + int64_t num_attention_heads, + int64_t attention_head_dim) : dim(dim) { + blocks["attn1"] = std::shared_ptr(new LTXAttention( + dim, num_attention_heads, attention_head_dim, /*cross=*/-1, true, true, /*apply_rope=*/false)); + blocks["ff"] = std::shared_ptr(new FeedForward(dim, 4 * dim)); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); + auto ff = std::dynamic_pointer_cast(blocks["ff"]); + auto xn = ggml_rms_norm(ctx->ggml_ctx, x, 1e-6f); + auto a = attn1->forward(ctx, xn); + x = ggml_add(ctx->ggml_ctx, x, a); + auto xn2 = ggml_rms_norm(ctx->ggml_ctx, x, 1e-6f); + auto f = ff->forward(ctx, xn2); + x = ggml_add(ctx->ggml_ctx, x, f); + return x; + } + }; + + // EmbeddingsConnector — LTX-2.3 prompt re-embedder. + // 128 learnable registers prepended to the projected text embeddings, then + // passed through a stack of self-attention + FF blocks. + class EmbeddingsConnector : public GGMLBlock { + public: + int64_t dim; + int64_t num_registers; + int64_t num_blocks; + + protected: + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override { + params["learnable_registers"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, num_registers); + } + + public: + EmbeddingsConnector(int64_t dim, + int64_t num_attention_heads, + int64_t attention_head_dim, + int64_t num_registers = 128, + int64_t num_blocks = 8) + : dim(dim), num_registers(num_registers), num_blocks(num_blocks) { + for (int64_t i = 0; i < num_blocks; ++i) { + blocks["transformer_1d_blocks." + std::to_string(i)] = + std::shared_ptr(new EmbeddingsConnectorBlock( + dim, num_attention_heads, attention_head_dim)); + } + } + + // text_embeddings: [dim, L, N, 1] + // Output: [dim, num_registers, N, 1] (ALWAYS 128 tokens) + // + // Reference: LTX-2 Embeddings1DConnector._replace_padded_with_learnable_registers. + // The input is assumed LEFT-padded to num_registers tokens; real + // text sits at the END. The connector flips the mask and writes + // real text into positions [0..L-1] and learnable_registers[L..R-1] + // into positions [L..R-1]. Sequence length is FIXED at num_registers. 
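+        // Example (hypothetical prompt length): a 20-token prompt fills
+        // positions 0..19, learnable_registers[20..127] fill positions
+        // 20..127, so downstream prompt cross-attention always sees
+        // exactly 128 keys.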
+ // + // Our caller passes the real text with L ≤ num_registers tokens. + // We implement the reference's math directly: + // out[:L] = text + // out[L:R] = learnable_registers[L:R] + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* text_embeddings) { + ggml_tensor* reg = params["learnable_registers"]; // [dim, num_registers] + int64_t D = text_embeddings->ne[0]; + int64_t L = text_embeddings->ne[1]; + int64_t N = text_embeddings->ne[2]; + GGML_ASSERT(L <= num_registers); + + ggml_tensor* x; + if (L == num_registers) { + // No padding needed — use text directly. + x = text_embeddings; } else { - auto h = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2)); // [ID, N*IC, IH, IW] - int64_t offset = h->nb[2] * h->ne[2]; + // Slice learnable_registers[L..R] = [dim, R-L]. + auto reg_slice = ggml_view_2d(ctx->ggml_ctx, reg, + D, num_registers - L, + reg->nb[1], reg->nb[1] * L); + reg_slice = ggml_cont(ctx->ggml_ctx, reg_slice); + // Reshape to [dim, R-L, 1] and broadcast across N if needed. + auto reg_3d = ggml_reshape_3d(ctx->ggml_ctx, reg_slice, + D, num_registers - L, 1); + if (N != 1) { + auto target = ggml_new_tensor_3d(ctx->ggml_ctx, reg_3d->type, + D, num_registers - L, N); + reg_3d = ggml_repeat(ctx->ggml_ctx, reg_3d, target); + } + // Concatenate text FIRST then registers — matches the + // reference output layout [text(L), registers(L..R)]. + x = ggml_concat(ctx->ggml_ctx, text_embeddings, reg_3d, 1); + } + + for (int64_t i = 0; i < num_blocks; ++i) { + auto b = std::dynamic_pointer_cast( + blocks["transformer_1d_blocks." + std::to_string(i)]); + x = b->forward(ctx, x); + } + // Final stateless rms_norm (matches reference). + x = ggml_rms_norm(ctx->ggml_ctx, x, 1e-6f); + return x; + } + }; + + // Transformer block for LTX-2.3 (video-only forward). 
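+    // In forward() the 9 rows of scale_shift_table are consumed as
+    // (shift, scale, gate) for video self-attn (rows 0-2), for the video
+    // FFN (rows 3-5) and for the prompt cross-attn query-side modulation
+    // (rows 6-8); see the slice(0)..slice(8) calls below.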
+ // Every weight slot in transformer_blocks.N is registered: + // attn1, attn2, audio_attn1, audio_attn2 (all gated, qk_norm) + // audio_to_video_attn, video_to_audio_attn (gated, no rope) + // ff, audio_ff + // scale_shift_table [dim, 9] + // audio_scale_shift_table [audio_dim, 9] + // prompt_scale_shift_table [dim, 2] + // audio_prompt_scale_shift_table [audio_dim, 2] + // scale_shift_table_a2v_ca_video [dim, 5] + // scale_shift_table_a2v_ca_audio [audio_dim, 5] + class LTX2VideoTransformerBlock : public GGMLBlock { + protected: + int64_t dim; + int64_t audio_dim; + + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override { + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, 9); + params["audio_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, audio_dim, 9); + params["prompt_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, 2); + params["audio_prompt_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, audio_dim, 2); + params["scale_shift_table_a2v_ca_video"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, 5); + params["scale_shift_table_a2v_ca_audio"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, audio_dim, 5); + } + + public: + LTX2VideoTransformerBlock(int64_t dim, + int64_t num_attention_heads, + int64_t attention_head_dim, + int64_t cross_attention_dim, + int64_t audio_dim, + int64_t audio_num_attention_heads, + int64_t audio_attention_head_dim, + int64_t audio_cross_attention_dim, + float eps = 1e-6f) + : dim(dim), audio_dim(audio_dim) { + blocks["attn1"] = std::shared_ptr(new LTXAttention( + dim, num_attention_heads, attention_head_dim)); + blocks["attn2"] = std::shared_ptr(new LTXAttention( + dim, num_attention_heads, attention_head_dim, cross_attention_dim, true, true, false)); + blocks["audio_attn1"] = std::shared_ptr(new LTXAttention( + audio_dim, audio_num_attention_heads, audio_attention_head_dim)); + blocks["audio_attn2"] = std::shared_ptr(new LTXAttention( + audio_dim, audio_num_attention_heads, audio_attention_head_dim, + audio_cross_attention_dim, true, true, false)); - auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0); // [N*IC, IH, IW] - first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]); // [N*IC, 1, IH, IW] - auto first_frame_pad = first_frame; - for (int i = 1; i < (time_kernel_size - 1) / 2; i++) { - first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2); + // Cross-modal attention — query_dim from target modality, kv from source. 
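+            // These cross-modal blocks are registered only so the checkpoint
+            // loader finds every expected tensor; the video-only forward()
+            // below never calls them (step 3 is skipped).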
+ blocks["audio_to_video_attn"] = std::shared_ptr(new LTXAttention( + dim, audio_num_attention_heads, audio_attention_head_dim, + audio_dim, true, true, false, + audio_num_attention_heads, audio_attention_head_dim)); + blocks["video_to_audio_attn"] = std::shared_ptr(new LTXAttention( + audio_dim, audio_num_attention_heads, audio_attention_head_dim, + dim, true, true, false, + audio_num_attention_heads, audio_attention_head_dim)); + + blocks["ff"] = std::shared_ptr(new FeedForward(dim, 4 * dim)); + blocks["audio_ff"] = std::shared_ptr(new FeedForward(audio_dim, 4 * audio_dim)); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* hidden, + ggml_tensor* encoder, + ggml_tensor* temb, + ggml_tensor* rope_cos = nullptr, + ggml_tensor* rope_sin = nullptr, + ggml_tensor* encoder_mask = nullptr, + int block_idx = -1) { + auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); + auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + auto ff = std::dynamic_pointer_cast(blocks["ff"]); + + auto probe_tensor = [&](const char* name, ggml_tensor* t) { + if (block_idx == 0) { + auto dup = ggml_dup(ctx->ggml_ctx, t); + ggml_set_name(dup, name); + debug_probes().add(name, dup); } + }; + + ggml_tensor* sst = params["scale_shift_table"]; // [dim, 9] + auto temb_r = ggml_reshape_4d(ctx->ggml_ctx, temb, dim, 9, temb->ne[1], temb->ne[2]); + auto ada = ggml_add(ctx->ggml_ctx, temb_r, sst); + + auto slice = [&](int idx) { + auto v = ggml_view_4d(ctx->ggml_ctx, ada, ada->ne[0], 1, ada->ne[2], ada->ne[3], + ada->nb[1], ada->nb[2], ada->nb[3], ada->nb[1] * idx); + return ggml_reshape_3d(ctx->ggml_ctx, v, ada->ne[0], ada->ne[2], ada->ne[3]); + }; + auto shift_msa = slice(0); + auto scale_msa = slice(1); + auto gate_msa = slice(2); + auto shift_mlp = slice(3); + auto scale_mlp = slice(4); + auto gate_mlp = slice(5); + auto shift_text_q = slice(6); + auto scale_text_q = slice(7); + auto gate_text_q = slice(8); + + const char* dbg_mode = std::getenv("LTXV_DEBUG_MODE"); + bool skip_mod = dbg_mode && std::strstr(dbg_mode, "no_mod"); + bool skip_attn1 = dbg_mode && std::strstr(dbg_mode, "no_attn1"); + bool skip_attn2 = dbg_mode && std::strstr(dbg_mode, "no_attn2"); + bool skip_ff = dbg_mode && std::strstr(dbg_mode, "no_ff"); + bool skip_scale = dbg_mode && std::strstr(dbg_mode, "no_scale"); + bool skip_shift = dbg_mode && std::strstr(dbg_mode, "no_shift"); + bool skip_gate = dbg_mode && std::strstr(dbg_mode, "no_gate"); + bool ret_h_norm1 = dbg_mode && std::strstr(dbg_mode, "ret=h_norm1"); + bool ret_scale_msa = dbg_mode && std::strstr(dbg_mode, "ret=scale_msa"); + bool ret_attn1_out = dbg_mode && std::strstr(dbg_mode, "ret=attn1_out"); + + if (ret_scale_msa) { + // Broadcast scale_msa to hidden shape so the caller's reshape works. + // scale_msa is [dim, T_temb, N]; hidden is [dim, L, N]. Broadcast + // the first axis with repeat. 
+ auto target = ggml_new_tensor_3d(ctx->ggml_ctx, scale_msa->type, + scale_msa->ne[0], hidden->ne[1], scale_msa->ne[2]); + return ggml_repeat(ctx->ggml_ctx, scale_msa, target); + } + + probe_tensor("blk0_hidden_in", hidden); + probe_tensor("blk0_encoder_in", encoder); + probe_tensor("blk0_scale_msa", scale_msa); + probe_tensor("blk0_shift_msa", shift_msa); + probe_tensor("blk0_gate_msa", gate_msa); + probe_tensor("blk0_scale_text_q", scale_text_q); + probe_tensor("blk0_shift_text_q", shift_text_q); + probe_tensor("blk0_gate_text_q", gate_text_q); - auto last_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], offset * (h->ne[3] - 1)); // [N*IC, IH, IW] - last_frame = ggml_reshape_4d(ctx, last_frame, last_frame->ne[0], last_frame->ne[1], 1, last_frame->ne[2]); // [N*IC, 1, IH, IW] - auto last_frame_pad = last_frame; - for (int i = 1; i < (time_kernel_size - 1) / 2; i++) { - last_frame_pad = ggml_concat(ctx, last_frame_pad, last_frame, 2); + // 1. Video self-attention + auto h_norm = ggml_rms_norm(ctx->ggml_ctx, hidden, 1e-6f); + probe_tensor("blk0_after_norm1", h_norm); + if (!skip_mod) { + if (!skip_scale) { + h_norm = ggml_add(ctx->ggml_ctx, h_norm, ggml_mul(ctx->ggml_ctx, h_norm, scale_msa)); } + if (!skip_shift) { + h_norm = ggml_add(ctx->ggml_ctx, h_norm, shift_msa); + } + } + probe_tensor("blk0_after_mod1", h_norm); + if (ret_h_norm1) { + return h_norm; + } + if (!skip_attn1) { + auto attn_out = attn1->forward(ctx, h_norm, nullptr, + rope_cos, rope_sin, nullptr, nullptr, nullptr); + probe_tensor("blk0_after_attn1", attn_out); + if (ret_attn1_out) { + return attn_out; + } + if (!skip_mod && !skip_gate) { + hidden = ggml_add(ctx->ggml_ctx, hidden, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa)); + } else { + hidden = ggml_add(ctx->ggml_ctx, hidden, attn_out); + } + probe_tensor("blk0_after_attn1_residual", hidden); + } + + // 2. Prompt cross-attention with Q modulation + auto h_norm2 = ggml_rms_norm(ctx->ggml_ctx, hidden, 1e-6f); + probe_tensor("blk0_after_norm2", h_norm2); + if (!skip_mod) { + h_norm2 = ggml_add(ctx->ggml_ctx, h_norm2, ggml_mul(ctx->ggml_ctx, h_norm2, scale_text_q)); + h_norm2 = ggml_add(ctx->ggml_ctx, h_norm2, shift_text_q); + } + probe_tensor("blk0_after_mod2", h_norm2); + if (!skip_attn2) { + const char* attn2_prefix = (block_idx == 0) ? "blk0_attn2" : nullptr; + auto ca_out = attn2->forward(ctx, h_norm2, encoder, + nullptr, nullptr, nullptr, nullptr, encoder_mask, + attn2_prefix); + probe_tensor("blk0_after_attn2", ca_out); + if (!skip_mod) { + ca_out = ggml_mul(ctx->ggml_ctx, ca_out, gate_text_q); + } + hidden = ggml_add(ctx->ggml_ctx, hidden, ca_out); + probe_tensor("blk0_after_attn2_residual", hidden); + } + + // 3. a2v/v2a cross-attention — SKIPPED (video-only mode). + + // 4. 
FFN + auto h_norm3 = ggml_rms_norm(ctx->ggml_ctx, hidden, 1e-6f); + probe_tensor("blk0_after_norm3", h_norm3); + if (!skip_mod) { + h_norm3 = ggml_add(ctx->ggml_ctx, h_norm3, ggml_mul(ctx->ggml_ctx, h_norm3, scale_mlp)); + h_norm3 = ggml_add(ctx->ggml_ctx, h_norm3, shift_mlp); + } + probe_tensor("blk0_after_mod3", h_norm3); + if (!skip_ff) { + auto ff_out = ff->forward(ctx, h_norm3); + probe_tensor("blk0_after_ff", ff_out); + if (!skip_mod) { + hidden = ggml_add(ctx->ggml_ctx, hidden, ggml_mul(ctx->ggml_ctx, ff_out, gate_mlp)); + } else { + hidden = ggml_add(ctx->ggml_ctx, hidden, ff_out); + } + probe_tensor("blk0_after_ff_residual", hidden); + } + return hidden; + } + }; + + // ================================================================= + // 3-D RoPE (interleaved) + // ================================================================= + + struct RopeTables { + std::vector cos; + std::vector sin; + int64_t L = 0; + int64_t dim = 0; + }; + + __STATIC_INLINE__ RopeTables compute_rope_ltx2(int num_frames, + int height, + int width, + int dim, + bool split_rope = true, + int patch_size = 1, + int patch_size_t = 1, + int base_frames = 20, + int base_h = 2048, + int base_w = 2048, + int vae_scale_t = 8, + int vae_scale_h = 32, + int vae_scale_w = 32, + int causal_offset = 1, + float fps = 24.f, + float theta = 10000.f) { + // Split-layout : cos/sin of size dim/2 per position (no duplication). + // Interleaved : cos/sin of size dim per position (repeat_interleave(2)). + RopeTables t; + int64_t pos_dim = split_rope ? (int64_t)(dim / 2) : (int64_t)dim; + t.dim = pos_dim; + t.L = (int64_t)num_frames * height * width; + t.cos.assign(t.L * pos_dim, 0.f); + t.sin.assign(t.L * pos_dim, 0.f); + + // Split: 3 pos-axes, dim/2 total freq slots → freq_per_axis = (dim/2) / 3 + // Interleaved: 6 rope elems, dim/6 per axis. + int num_axes = 3; + int slots = (int)pos_dim; // total per-position storage size + int freq_per_axis = slots / num_axes; + int pad = slots - num_axes * freq_per_axis; + + std::vector freqs(freq_per_axis); + if (freq_per_axis > 1) { + for (int i = 0; i < freq_per_axis; ++i) { + float exponent = (float)i / (float)(freq_per_axis - 1); + freqs[i] = std::pow(theta, exponent) * (float)M_PI / 2.f; + } + } else if (freq_per_axis == 1) { + freqs[0] = (float)M_PI / 2.f; + } + + int64_t idx = 0; + for (int f = 0; f < num_frames; ++f) { + float pix_start_t = (float)f * patch_size_t * vae_scale_t; + float pix_end_t = ((float)f * patch_size_t + patch_size_t) * vae_scale_t; + pix_start_t = std::max(0.f, pix_start_t + (float)causal_offset - (float)vae_scale_t); + pix_end_t = std::max(0.f, pix_end_t + (float)causal_offset - (float)vae_scale_t); + float mid_t = 0.5f * (pix_start_t + pix_end_t) / fps; + float gf = mid_t / (float)base_frames; + + for (int h = 0; h < height; ++h) { + float mid_h = ((float)h + 0.5f) * (float)patch_size * (float)vae_scale_h; + float gh = mid_h / (float)base_h; + for (int w = 0; w < width; ++w) { + float mid_w = ((float)w + 0.5f) * (float)patch_size * (float)vae_scale_w; + float gw = mid_w / (float)base_w; + float* co = &t.cos[idx * pos_dim]; + float* si = &t.sin[idx * pos_dim]; + + // Leading pad: cos=1, sin=0. For LTX-2.3 22B with dim=4096 split, + // pad_size = 2048 - 3 * (2048/3) = 2 (matches diffusers). 
+ for (int p = 0; p < pad; ++p) { + co[p] = 1.f; + si[p] = 0.f; + } + for (int k = 0; k < freq_per_axis; ++k) { + float ang_f = freqs[k] * (gf * 2.f - 1.f); + float ang_h = freqs[k] * (gh * 2.f - 1.f); + float ang_w = freqs[k] * (gw * 2.f - 1.f); + float vals[3] = {ang_f, ang_h, ang_w}; + if (split_rope) { + // Layout: per-position, values = [pad, (F0,H0,W0), (F1,H1,W1), ...] + for (int a = 0; a < 3; ++a) { + co[pad + k * 3 + a] = std::cos(vals[a]); + si[pad + k * 3 + a] = std::sin(vals[a]); + } + } else { + // Interleaved layout: each (ang) expands to (cos, cos) / (sin, sin). + for (int a = 0; a < 3; ++a) { + float c = std::cos(vals[a]); + float s = std::sin(vals[a]); + co[pad + 2 * (k * 3 + a) + 0] = c; + co[pad + 2 * (k * 3 + a) + 1] = c; + si[pad + 2 * (k * 3 + a) + 0] = s; + si[pad + 2 * (k * 3 + a) + 1] = s; + } + } + } + ++idx; + } + } + } + return t; + } + + // ================================================================= + // Full transformer + // ================================================================= + + class LTX2VideoTransformer3DModel : public GGMLBlock { + public: + int64_t in_channels; + int64_t out_channels; + int64_t num_layers; + int64_t num_attention_heads; + int64_t attention_head_dim; + int64_t inner_dim; + int64_t audio_inner_dim; + int64_t audio_num_attention_heads; + int64_t audio_attention_head_dim; + int64_t cross_attention_dim; + int64_t caption_channels; + int64_t audio_cross_attention_dim; + int64_t audio_in_channels; + int64_t audio_out_channels; + int64_t connector_num_registers; + int64_t connector_num_blocks; + int patch_size; + int patch_size_t; + + protected: + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override { + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, inner_dim, 2); + params["audio_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, audio_inner_dim, 2); + } + + public: + LTX2VideoTransformer3DModel(int64_t in_channels = 128, + int64_t out_channels = 128, + int patch_size = 1, + int patch_size_t = 1, + int64_t num_attention_heads = 32, + int64_t attention_head_dim = 128, + int64_t cross_attention_dim = 4096, + int64_t num_layers = 48, + int64_t caption_channels = 4096, + int64_t audio_in_channels = 128, + int64_t audio_out_channels = 128, + int64_t audio_num_attention_heads = 32, + int64_t audio_attention_head_dim = 64, + int64_t audio_cross_attention_dim = 2048, + int64_t connector_num_registers = 128, + int64_t connector_num_blocks = 8) + : in_channels(in_channels), + out_channels(out_channels), + num_layers(num_layers), + num_attention_heads(num_attention_heads), + attention_head_dim(attention_head_dim), + cross_attention_dim(cross_attention_dim), + caption_channels(caption_channels), + audio_cross_attention_dim(audio_cross_attention_dim), + audio_in_channels(audio_in_channels), + audio_out_channels(audio_out_channels), + audio_num_attention_heads(audio_num_attention_heads), + audio_attention_head_dim(audio_attention_head_dim), + connector_num_registers(connector_num_registers), + connector_num_blocks(connector_num_blocks), + patch_size(patch_size), + patch_size_t(patch_size_t) { + inner_dim = num_attention_heads * attention_head_dim; + audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim; + + // Force F32 on patchify weights: the combination of tiny in_channels + // (128) and BF16 storage triggers a matmul pathway that gives wildly + // wrong magnitudes on some ggml backends (observed 6e9x explosion). 
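+            // With the 22B defaults above this is Linear(128 -> 4096) for
+            // video and Linear(128 -> 2048) for audio (audio inner dim =
+            // 32 heads x 64 head-dim).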
+ blocks["patchify_proj"] = std::shared_ptr(new Linear(in_channels, inner_dim, true, /*force_f32=*/true)); + blocks["audio_patchify_proj"] = std::shared_ptr(new Linear(audio_in_channels, audio_inner_dim, true, /*force_f32=*/true)); + + blocks["adaln_single"] = std::shared_ptr(new AdaLayerNormSingle(inner_dim, 9)); + blocks["audio_adaln_single"] = std::shared_ptr(new AdaLayerNormSingle(audio_inner_dim, 9)); + + blocks["prompt_adaln_single"] = std::shared_ptr(new AdaLayerNormSingle(inner_dim, 2)); + blocks["audio_prompt_adaln_single"] = std::shared_ptr(new AdaLayerNormSingle(audio_inner_dim, 2)); + + blocks["av_ca_video_scale_shift_adaln_single"] = + std::shared_ptr(new AdaLayerNormSingle(inner_dim, 4)); + blocks["av_ca_audio_scale_shift_adaln_single"] = + std::shared_ptr(new AdaLayerNormSingle(audio_inner_dim, 4)); + blocks["av_ca_a2v_gate_adaln_single"] = + std::shared_ptr(new AdaLayerNormSingle(inner_dim, 1)); + blocks["av_ca_v2a_gate_adaln_single"] = + std::shared_ptr(new AdaLayerNormSingle(audio_inner_dim, 1)); + + blocks["video_embeddings_connector"] = std::shared_ptr(new EmbeddingsConnector( + inner_dim, num_attention_heads, attention_head_dim, + connector_num_registers, connector_num_blocks)); + blocks["audio_embeddings_connector"] = std::shared_ptr(new EmbeddingsConnector( + audio_inner_dim, audio_num_attention_heads, audio_attention_head_dim, + connector_num_registers, connector_num_blocks)); + + for (int64_t i = 0; i < num_layers; ++i) { + blocks["transformer_blocks." + std::to_string(i)] = + std::shared_ptr(new LTX2VideoTransformerBlock( + inner_dim, num_attention_heads, attention_head_dim, cross_attention_dim, + audio_inner_dim, audio_num_attention_heads, audio_attention_head_dim, audio_cross_attention_dim)); + } + + blocks["proj_out"] = std::shared_ptr(new Linear(inner_dim, out_channels, true)); + blocks["audio_proj_out"] = std::shared_ptr(new Linear(audio_inner_dim, audio_out_channels, true)); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* hidden_states, + ggml_tensor* encoder_hidden_states, + ggml_tensor* timestep, + ggml_tensor* rope_cos, + ggml_tensor* rope_sin, + ggml_tensor* encoder_mask = nullptr) { + auto patchify = std::dynamic_pointer_cast(blocks["patchify_proj"]); + auto adaln = std::dynamic_pointer_cast(blocks["adaln_single"]); + auto connector = std::dynamic_pointer_cast(blocks["video_embeddings_connector"]); + auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); + + const char* probe = std::getenv("LTXV_PROBE"); + const char* stage_env = std::getenv("LTXV_PROBE_STAGE"); + int stage = stage_env ? 
std::atoi(stage_env) : -1; - x = ggml_concat(ctx, first_frame_pad, x, 2); - x = ggml_concat(ctx, x, last_frame_pad, 2); + (void)stage; + (void)probe; + + auto& probes = debug_probes(); + probes.clear(); + + auto dup_hs = ggml_dup(ctx->ggml_ctx, hidden_states); + ggml_set_name(dup_hs, "dbg_patchify_in"); + probes.add("dbg_patchify_in", dup_hs); + + auto x = patchify->forward(ctx, hidden_states); + + auto dup_x = ggml_dup(ctx->ggml_ctx, x); + ggml_set_name(dup_x, "dbg_after_patchify"); + probes.add("dbg_after_patchify", dup_x); + if (probe && std::strcmp(probe, "after_proj_in") == 0) { + ggml_set_name(x, "ltxv_probe_out"); + return ggml_cont(ctx->ggml_ctx, x); } - x = conv->forward(ctx, x); + auto te_pair = adaln->forward(ctx, timestep); + auto temb = te_pair.first; + auto embedded_timestep = te_pair.second; + if (probe && std::strcmp(probe, "temb") == 0) { + ggml_set_name(temb, "ltxv_probe_out"); + // temb shape doesn't match transformer output expected shape; + // skip rest of forward by returning early — this corrupts the + // sampler but is acceptable for diagnostic-only runs. + return ggml_cont(ctx->ggml_ctx, temb); + } + if (probe && std::strcmp(probe, "embedded_timestep") == 0) { + ggml_set_name(embedded_timestep, "ltxv_probe_out"); + return ggml_cont(ctx->ggml_ctx, embedded_timestep); + } + + temb = ggml_reshape_4d(ctx->ggml_ctx, temb, temb->ne[0], 1, temb->ne[1], 1); + + auto encoder = connector->forward(ctx, encoder_hidden_states); + + int64_t max_i = num_layers; + const char* dbg_env = std::getenv("LTXV_DEBUG_MAX_LAYERS"); + if (dbg_env) { + int64_t dbg = std::atoi(dbg_env); + if (dbg > 0 && dbg < max_i) max_i = dbg; + } + for (int64_t i = 0; i < max_i; ++i) { + auto blk = std::dynamic_pointer_cast( + blocks["transformer_blocks." + std::to_string(i)]); + x = blk->forward(ctx, x, encoder, temb, rope_cos, rope_sin, encoder_mask, (int)i); + // Probe the first few block outputs. + if (i < 3) { + auto dup = ggml_dup(ctx->ggml_ctx, x); + std::string name = "dbg_after_block" + std::to_string(i); + ggml_set_name(dup, name.c_str()); + debug_probes().add(name, dup); + } + } + + ggml_tensor* sst = params["scale_shift_table"]; + auto et_r = ggml_reshape_4d(ctx->ggml_ctx, embedded_timestep, + inner_dim, 1, embedded_timestep->ne[1], 1); + auto sst_r = ggml_reshape_4d(ctx->ggml_ctx, sst, inner_dim, 2, 1, 1); + auto target = ggml_new_tensor_4d(ctx->ggml_ctx, et_r->type, + inner_dim, 2, et_r->ne[2], 1); + auto et_expand = ggml_repeat(ctx->ggml_ctx, et_r, target); + auto mod = ggml_add(ctx->ggml_ctx, et_expand, sst_r); + + auto shift = ggml_view_3d(ctx->ggml_ctx, mod, inner_dim, 1, mod->ne[2], + mod->nb[1], mod->nb[2], 0); + auto scale = ggml_view_3d(ctx->ggml_ctx, mod, inner_dim, 1, mod->ne[2], + mod->nb[1], mod->nb[2], mod->nb[1]); + // norm_out (LayerNorm, elementwise_affine=False, eps=1e-6) — + // matches reference LTXModel._process_output. Without this the + // post-block activations (std≈200+ after 48 layers) leak into + // the predicted velocity and the sampler diverges. + x = ggml_norm(ctx->ggml_ctx, x, 1e-6f); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, x, scale)); + x = ggml_add(ctx->ggml_ctx, x, shift); + x = proj_out->forward(ctx, x); return x; } }; -}; + // ================================================================= + // Transformer runner + // ================================================================= + + // Globally-mutable "probe table" that any forward path can push named + // intermediates into. 
LTXVRunner::compute then reads them after running + // the graph and logs stats. Keeps the probe infrastructure out of the + // block forward signatures. + struct LTXVRunner : public GGMLRunner { + LTX2VideoTransformer3DModel dit; + RopeTables rope_tbl; + + LTXVRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "model.diffusion_model", + SDVersion version = VERSION_COUNT) + : GGMLRunner(backend, offload_params_to_cpu) { + dit.init(params_ctx, tensor_storage_map, prefix); + } + + // Debug: override to only run the first N transformer blocks during + // build_graph. Set via LTXV_DEBUG_MAX_LAYERS env var (0 = all). + int debug_max_layers() const { + const char* e = std::getenv("LTXV_DEBUG_MAX_LAYERS"); + return e ? std::atoi(e) : 0; + } + + std::string get_desc() override { return "ltxv2.3"; } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + dit.get_param_tensors(tensors, prefix); + } + + struct ggml_cgraph* build_graph(const sd::Tensor& x, + const sd::Tensor& timesteps, + const sd::Tensor& context, + const sd::Tensor* mask_bias) { + auto* compute = compute_ctx; + auto gf = ggml_new_graph_custom(compute, LTXV_GRAPH_SIZE, false); + + auto x_t = make_input(x); + auto ts_t = make_input(timesteps); + auto c_t = make_input(context); + ggml_tensor* m_t = nullptr; + if (mask_bias != nullptr && !mask_bias->empty()) { + m_t = make_input(*mask_bias); + } + + int64_t W = x_t->ne[0]; + int64_t H = x_t->ne[1]; + int64_t F = x_t->ne[2]; + int64_t C = x_t->ne[3]; + GGML_ASSERT(C == dit.in_channels); + + // LTX-2.3 uses split rope → cos/sin is inner_dim/2 per position. + ggml_tensor* rope_cos = nullptr; + ggml_tensor* rope_sin = nullptr; + const char* probe_stage_env = std::getenv("LTXV_PROBE_STAGE"); + if (!probe_stage_env) { + rope_tbl = compute_rope_ltx2((int)F, (int)H, (int)W, (int)dit.inner_dim, /*split_rope=*/true); + rope_cos = ggml_new_tensor_2d(compute, GGML_TYPE_F32, + rope_tbl.dim, rope_tbl.L); + rope_sin = ggml_new_tensor_2d(compute, GGML_TYPE_F32, + rope_tbl.dim, rope_tbl.L); + set_backend_tensor_data(rope_cos, rope_tbl.cos.data()); + set_backend_tensor_data(rope_sin, rope_tbl.sin.data()); + } + + // Flatten the latent grid into tokens. Note: the exact (f, h, w) + // order implied by this permute doesn't perfectly match RoPE's + // meshgrid ordering — that's a TODO flagged in docs/ltxv.md. + // Using the previously-validated permute that at least produces + // a consistent round-trip shape. + auto hidden = ggml_ext_cont(compute, + ggml_ext_torch_permute(compute, x_t, 3, 0, 1, 2)); + hidden = ggml_reshape_3d(compute, hidden, C, W * H * F, 1); + + const char* bypass = std::getenv("LTXV_BYPASS"); + const char* stage_env = std::getenv("LTXV_PROBE_STAGE"); + bool skip_final_reshape = stage_env != nullptr; + ggml_tensor* out; + if (bypass && std::strlen(bypass) > 0) { + out = ggml_cont(compute, x_t); + } else { + auto rctx = get_context(); + out = dit.forward(&rctx, hidden, c_t, ts_t, rope_cos, rope_sin, m_t); + if (!skip_final_reshape) { + out = ggml_reshape_4d(compute, out, C, W, H, F); + out = ggml_ext_cont(compute, ggml_ext_torch_permute(compute, out, 1, 2, 3, 0)); + } + } + + // Expand probes first, then `out` last so it remains the + // graph's final node (which get_compute_graph names final_result). 
+ for (auto& p : debug_probes().entries) { + if (p.tensor) ggml_build_forward_expand(gf, p.tensor); + } + ggml_build_forward_expand(gf, out); + return gf; + } + + // Dump min/max/mean/stddev of an sd::Tensor to the log. + // Used to locate where the forward path becomes seed-invariant. + template + static void log_tensor_stats(const char* label, const sd::Tensor& t) { + if (t.empty()) { + LOG_INFO("[ltxv.stats] %s: EMPTY", label); + return; + } + const int64_t n = t.numel(); + double mn = 1e30, mx = -1e30, sum = 0.0, sum_sq = 0.0; + size_t nan_count = 0; + const T* data = t.data(); + for (int64_t i = 0; i < n; ++i) { + double v = static_cast(data[i]); + if (std::isnan(v)) { + ++nan_count; + continue; + } + if (v < mn) mn = v; + if (v > mx) mx = v; + sum += v; + sum_sq += v * v; + } + int64_t valid = n - static_cast(nan_count); + double mean = valid > 0 ? sum / valid : 0; + double var = valid > 0 ? sum_sq / valid - mean * mean : 0; + double sd = var > 0 ? std::sqrt(var) : 0; + std::string shape_str; + for (size_t i = 0; i < t.shape().size(); ++i) { + if (i) shape_str += "x"; + shape_str += std::to_string(t.shape()[i]); + } + LOG_INFO("[ltxv.stats] %s: shape=[%s] n=%ld min=%.6g max=%.6g mean=%.6g std=%.6g nan=%zu", + label, shape_str.c_str(), (long)n, mn, mx, mean, sd, nan_count); + } + + sd::Tensor compute(int n_threads, + const sd::Tensor& x, + const sd::Tensor& timesteps, + const sd::Tensor& context, + const sd::Tensor& mask) { + log_tensor_stats("transformer_in_x", x); + log_tensor_stats("transformer_in_timesteps", timesteps); + log_tensor_stats("transformer_in_context", context); + + const char* bypass = std::getenv("LTXV_BYPASS"); + if (bypass && std::strlen(bypass) > 0) { + // Bypass the entire transformer compute: return the input + // unchanged so the VAE sees seed-dependent data and we can + // validate the rest of the pipeline. + LOG_INFO("[ltxv.stats] transformer bypassed (LTXV_BYPASS set)"); + return x; + } + + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(x, timesteps, context, + mask.empty() ? nullptr : &mask); + }; + auto result = GGMLRunner::compute(get_graph, n_threads, false); + if (!result.has_value()) return {}; + sd::Tensor out = std::move(*result); + log_tensor_stats("transformer_out", out); + // Dump any debug-tagged intermediate tensor from the graph so we + // can compare against the PyTorch reference. We enumerate every + // registered probe rather than a hardcoded list so new probes + // (e.g. blk0_*) are picked up automatically. + std::vector probe_names; + for (auto& p : debug_probes().entries) { + probe_names.push_back(p.name); + } + for (const auto& nm : probe_names) { + const char* name = nm.c_str(); + ggml_tensor* t = ggml_get_tensor(compute_ctx, name); + if (!t) continue; + const size_t nb = ggml_nbytes(t); + std::vector cpu(ggml_nelements(t)); + if (t->type == GGML_TYPE_F32) { + ggml_backend_tensor_get(t, cpu.data(), 0, nb); + } else if (t->type == GGML_TYPE_F16) { + std::vector tmp(ggml_nelements(t)); + ggml_backend_tensor_get(t, tmp.data(), 0, nb); + for (size_t i = 0; i < cpu.size(); ++i) { + cpu[i] = ggml_fp16_to_fp32(tmp[i]); + } + } else { + LOG_INFO("[ltxv.stats] %s: type=%d (skipping stats)", name, (int)t->type); + continue; + } + double mn = 1e30, mx = -1e30, sum = 0, sum_sq = 0; + size_t nan_count = 0; + for (float v : cpu) { + if (std::isnan(v)) { ++nan_count; continue; } + if (v < mn) mn = v; + if (v > mx) mx = v; + sum += v; sum_sq += v * v; + } + size_t valid = cpu.size() - nan_count; + double mean = valid > 0 ? 
sum / valid : 0; + double var = valid > 0 ? sum_sq / valid - mean * mean : 0; + double sd = var > 0 ? std::sqrt(var) : 0; + LOG_INFO("[ltxv.stats] %s: shape=[%lld,%lld,%lld,%lld] n=%zu min=%.4g max=%.4g mean=%.4g std=%.4g nan=%zu", + name, + (long long)t->ne[0], (long long)t->ne[1], + (long long)t->ne[2], (long long)t->ne[3], + cpu.size(), mn, mx, mean, sd, nan_count); + } + return out; + } + }; + + // ================================================================= + // LTX-2.3 VAE + // ================================================================= + // + // Structure inferred from `ltx-2.3-22b-dev.safetensors`. + // Encoder has 9 top-level `down_blocks.N` groups alternating res-stacks and + // downsampler convs: + // 0: res × 4 @ 128 1: spatial(1,2,2) 128→256 2: res × 6 @ 256 + // 3: temporal(2,1,1) 256→512 4: res × 4 @ 512 5: st(2,2,2) 512→1024 + // 6: res × 2 @ 1024 7: st(2,2,2) 1024→1024 8: res × 2 @ 1024 + // Decoder mirror (sizes from checkpoint): + // 0: res × 2 @ 1024 1: upsamp st(2,2,2) conv[4096,1024] → 512 + // 2: res × 2 @ 512 3: upsamp st(2,2,2) conv[4096,512] → 512 + // 4: res × 4 @ 512 5: upsamp temporal(2,1,1) conv[512,512] → 256 + // 6: res × 6 @ 256 7: upsamp spatial(1,2,2) conv[512,256] → 128 + // 8: res × 4 @ 128 + + // VAE residual block — diffusers' LTX2VideoResnetBlock3d simplified for + // the LTX-2.3 checkpoint layout (no timestep conditioning, no shortcut + // conv, no learned affine in the norms). + // norm1 (PerChannelRMSNorm stateless) → silu → conv1 → + // norm2 → silu → conv2 → + residual + class VAEResBlock : public GGMLBlock { + protected: + int64_t channels; + + public: + VAEResBlock(int64_t channels) : channels(channels) { + // PerChannelRMSNorm is stateless (no weight), so the checkpoint + // has no norm tensors for these — we just do the arithmetic + // before each conv to keep activations bounded. + blocks["conv1"] = std::shared_ptr(new CausalConv3d(channels, channels, {3, 3, 3})); + blocks["conv2"] = std::shared_ptr(new CausalConv3d(channels, channels, {3, 3, 3})); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* h, bool causal = true) { + auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); + auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); + auto residual = h; + // Stateless per-channel RMS normalisation to bound activations. + h = ggml_ext_cont(ctx->ggml_ctx, + ggml_ext_torch_permute(ctx->ggml_ctx, h, 3, 0, 1, 2)); + h = ggml_rms_norm(ctx->ggml_ctx, h, 1e-8f); + h = ggml_ext_cont(ctx->ggml_ctx, + ggml_ext_torch_permute(ctx->ggml_ctx, h, 1, 2, 3, 0)); + h = ggml_silu_inplace(ctx->ggml_ctx, h); + h = conv1->forward(ctx, h, causal); + h = ggml_ext_cont(ctx->ggml_ctx, + ggml_ext_torch_permute(ctx->ggml_ctx, h, 3, 0, 1, 2)); + h = ggml_rms_norm(ctx->ggml_ctx, h, 1e-8f); + h = ggml_ext_cont(ctx->ggml_ctx, + ggml_ext_torch_permute(ctx->ggml_ctx, h, 1, 2, 3, 0)); + h = ggml_silu_inplace(ctx->ggml_ctx, h); + h = conv2->forward(ctx, h, causal); + return ggml_add(ctx->ggml_ctx, h, residual); + } + }; + + class VAEResStack : public GGMLBlock { + protected: + int64_t num_layers; + + public: + VAEResStack(int64_t channels, int64_t num_layers) : num_layers(num_layers) { + for (int64_t i = 0; i < num_layers; ++i) { + blocks["res_blocks." + std::to_string(i)] = + std::shared_ptr(new VAEResBlock(channels)); + } + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* h, bool causal = true) { + for (int64_t i = 0; i < num_layers; ++i) { + auto rn = std::dynamic_pointer_cast(blocks["res_blocks." 
+ std::to_string(i)]);
+                h = rn->forward(ctx, h, causal);
+            }
+            return h;
+        }
+    };
+
+    // Downsampler: conv (in_ch → conv_out_ch) then channel-inflation via reshape.
+    class VAEDownsampler : public GGMLBlock {
+    protected:
+        int64_t in_channels;
+        int64_t conv_out_channels;
+        std::tuple<int, int, int> stride;
+
+    public:
+        VAEDownsampler(int64_t in_channels, int64_t conv_out_channels, std::tuple<int, int, int> stride)
+            : in_channels(in_channels), conv_out_channels(conv_out_channels), stride(stride) {
+            blocks["conv"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_channels, conv_out_channels, {3, 3, 3}));
+        }
+
+        ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* h, bool causal = true) {
+            auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
+            h = conv->forward(ctx, h, causal);
+            int st_t = std::get<0>(stride), st_h = std::get<1>(stride), st_w = std::get<2>(stride);
+            int64_t W = h->ne[0], H = h->ne[1], F = h->ne[2], C = h->ne[3];
+            h = ggml_cont(ctx->ggml_ctx, h);
+            h = ggml_reshape_4d(ctx->ggml_ctx, h, W / st_w, H / st_h, F / st_t, C * st_w * st_h * st_t);
+            return h;
+        }
+    };
+
+    class VAEUpsampler : public GGMLBlock {
+    protected:
+        int64_t in_channels;
+        int64_t conv_out_channels;
+        std::tuple<int, int, int> stride;
+
+    public:
+        VAEUpsampler(int64_t in_channels, int64_t conv_out_channels, std::tuple<int, int, int> stride)
+            : in_channels(in_channels), conv_out_channels(conv_out_channels), stride(stride) {
+            blocks["conv"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_channels, conv_out_channels, {3, 3, 3}));
+        }
+
+        ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* h, bool causal = false) {
+            auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
+            h = conv->forward(ctx, h, causal);
+            int st_t = std::get<0>(stride), st_h = std::get<1>(stride), st_w = std::get<2>(stride);
+            h = ggml_cont(ctx->ggml_ctx, h);
+            h = depth_to_space_3d(ctx->ggml_ctx, h, st_t, st_h, st_w);
+            // Diffusers LTX2VideoUpsampler3d drops the first (st_t - 1) temporal
+            // samples so each upsampled chunk boundary stays causal and the
+            // overall frame count follows f_out = (f_in - 1) * st_t + 1 when
+            // composed across multiple temporal upsamples.
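+            // Worked example (illustrative numbers): with f_in = 16 latent
+            // frames, the decoder's three temporal x2 upsamples compose as
+            // 16 -> 31 -> 61 -> 121 output frames, i.e. exactly the 8k+1
+            // frame counts the pipeline pads requests to.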
+ if (st_t > 1) { + int64_t T_out = h->ne[2]; + int64_t T_keep = T_out - (st_t - 1); + int64_t offset_bytes = h->nb[2] * (st_t - 1); + h = ggml_view_4d(ctx->ggml_ctx, h, + h->ne[0], h->ne[1], T_keep, h->ne[3], + h->nb[1], h->nb[2], h->nb[3], offset_bytes); + h = ggml_cont(ctx->ggml_ctx, h); + } + return h; + } + }; + + class LTX23Encoder3d : public GGMLBlock { + public: + LTX23Encoder3d() { + blocks["conv_in"] = std::shared_ptr(new CausalConv3d(48, 128, {3, 3, 3})); + blocks["down_blocks.0"] = std::shared_ptr(new VAEResStack(128, 4)); + blocks["down_blocks.1"] = std::shared_ptr(new VAEDownsampler(128, 64, {1, 2, 2})); + blocks["down_blocks.2"] = std::shared_ptr(new VAEResStack(256, 6)); + blocks["down_blocks.3"] = std::shared_ptr(new VAEDownsampler(256, 256, {2, 1, 1})); + blocks["down_blocks.4"] = std::shared_ptr(new VAEResStack(512, 4)); + blocks["down_blocks.5"] = std::shared_ptr(new VAEDownsampler(512, 128, {2, 2, 2})); + blocks["down_blocks.6"] = std::shared_ptr(new VAEResStack(1024, 2)); + blocks["down_blocks.7"] = std::shared_ptr(new VAEDownsampler(1024, 128, {2, 2, 2})); + blocks["down_blocks.8"] = std::shared_ptr(new VAEResStack(1024, 2)); + blocks["conv_out"] = std::shared_ptr(new CausalConv3d(1024, 129, {3, 3, 3})); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, bool causal = true) { + auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); + auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); + int64_t W = x->ne[0], H = x->ne[1], F = x->ne[2], C = x->ne[3]; + GGML_ASSERT(W % 4 == 0 && H % 4 == 0); + // TODO: the reference patchify (ops.py:patchify) follows the + // "b c (f p) (h q) (w r) -> b (c p r q) f h w" + // convention where q (h_patch) is innermost in the channel axis. + // This bare reshape does not honour that — for T2V the encoder + // path is unused, but v2v/i2v workflows will need the inverse of + // depth_to_space_3d_patch here before we can trust them. + x = ggml_cont(ctx->ggml_ctx, x); + x = ggml_reshape_4d(ctx->ggml_ctx, x, W / 4, H / 4, F, C * 16); + auto h = conv_in->forward(ctx, x, causal); + for (int i = 0; i < 9; ++i) { + auto& blk = blocks["down_blocks." 
+ std::to_string(i)]; + if (i % 2 == 0) { + auto s = std::dynamic_pointer_cast(blk); + h = s->forward(ctx, h, causal); + } else { + auto s = std::dynamic_pointer_cast(blk); + h = s->forward(ctx, h, causal); + } + } + h = conv_out->forward(ctx, h, causal); + return h; + } + }; + + class LTX23Decoder3d : public GGMLBlock { + public: + LTX23Decoder3d() { + blocks["conv_in"] = std::shared_ptr(new CausalConv3d(128, 1024, {3, 3, 3})); + blocks["up_blocks.0"] = std::shared_ptr(new VAEResStack(1024, 2)); + blocks["up_blocks.1"] = std::shared_ptr(new VAEUpsampler(1024, 4096, {2, 2, 2})); + blocks["up_blocks.2"] = std::shared_ptr(new VAEResStack(512, 2)); + blocks["up_blocks.3"] = std::shared_ptr(new VAEUpsampler(512, 4096, {2, 2, 2})); + blocks["up_blocks.4"] = std::shared_ptr(new VAEResStack(512, 4)); + blocks["up_blocks.5"] = std::shared_ptr(new VAEUpsampler(512, 512, {2, 1, 1})); + blocks["up_blocks.6"] = std::shared_ptr(new VAEResStack(256, 6)); + blocks["up_blocks.7"] = std::shared_ptr(new VAEUpsampler(256, 512, {1, 2, 2})); + blocks["up_blocks.8"] = std::shared_ptr(new VAEResStack(128, 4)); + blocks["conv_out"] = std::shared_ptr(new CausalConv3d(128, 48, {3, 3, 3})); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* z, bool causal = false) { + auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); + auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); + auto h = conv_in->forward(ctx, z, causal); + for (int i = 0; i < 9; ++i) { + auto& blk = blocks["up_blocks." + std::to_string(i)]; + if (i % 2 == 0) { + auto s = std::dynamic_pointer_cast(blk); + h = s->forward(ctx, h, causal); + } else { + auto s = std::dynamic_pointer_cast(blk); + h = s->forward(ctx, h, causal); + } + } + // conv_norm_out (stateless PerChannelRMSNorm) + SiLU before conv_out, + // matching the reference video_vae decoder. Without these the + // output is O(1000) instead of O(1) per pixel. + { + PerChannelRMSNorm pn; + h = pn.forward(ctx, h); + } + h = ggml_silu(ctx->ggml_ctx, h); + h = conv_out->forward(ctx, h, causal); + // Un-patchify 4×4 spatial pack: ne [W, H, F, C*16] → [W*4, H*4, F, C]. + // The reference ops.py uses the patchify convention + // "b (c p r q) f h w -> b c (f p) (h q) (w r)" + // where the channel axis has h_patch (q) as the INNERMOST + // sub-index — not w_patch as in the intermediate upsampler. + // depth_to_space_3d_patch handles the sub-axis swap. + h = ggml_cont(ctx->ggml_ctx, h); + h = depth_to_space_3d_patch(ctx->ggml_ctx, h, /*p_t=*/1, /*p_h=*/4, /*p_w=*/4); + // sd.cpp's decode_video_outputs expects the 5-D layout + // [W, H, T, C, N=1] + // (batch last, time before channel). Our 4-D result is + // [W, H, T, C] — reinterpret by prepending N=1 to match. + h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3]); + // NOTE: ggml tensors are 4-D max. sd.cpp's tensor_to_sd_image + // reads the dimensionality from the sd::Tensor's shape vector + // (not from ggml ne), so we need to ensure the C++ side sees a + // 5-D shape. That happens in LTXVVAERunner::_compute by + // unsqueezing the resulting sd::Tensor before returning. 
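+            // Shape walk-through (illustrative numbers): a [24, 16, 16, 128]
+            // latent, i.e. (W/32, H/32, (F-1)/8+1, C), decodes to
+            // [768, 512, 121, 3] at this point: x8 spatial from the three
+            // spatial upsamplers, x4 from the un-patchify, temporal
+            // 121 = (16 - 1) * 8 + 1, and 48 / 16 = 3 output channels.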
+            return h;
+        }
+    };
+
+    class LTX23Autoencoder : public GGMLBlock {
+    public:
+        bool decode_only;
+
+    protected:
+        void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
+            params["per_channel_statistics.mean-of-means"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 128);
+            params["per_channel_statistics.std-of-means"]  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 128);
+        }
+
+    public:
+        LTX23Autoencoder(bool decode_only = true) : decode_only(decode_only) {
+            if (!decode_only) {
+                blocks["encoder"] = std::shared_ptr<GGMLBlock>(new LTX23Encoder3d());
+            }
+            blocks["decoder"] = std::shared_ptr<GGMLBlock>(new LTX23Decoder3d());
+        }
+
+        ggml_tensor* decode(GGMLRunnerContext* ctx, ggml_tensor* z) {
+            auto dec = std::dynamic_pointer_cast<LTX23Decoder3d>(blocks["decoder"]);
+            return dec->forward(ctx, z, /*causal=*/false);
+        }
+
+        ggml_tensor* encode(GGMLRunnerContext* ctx, ggml_tensor* x) {
+            auto enc = std::dynamic_pointer_cast<LTX23Encoder3d>(blocks["encoder"]);
+            return enc->forward(ctx, x, /*causal=*/true);
+        }
+    };
+
+    struct LTXVVAERunner : public VAE {
+        bool decode_only = true;
+        LTX23Autoencoder ae;
+
+        LTXVVAERunner(SDVersion version,
+                      ggml_backend_t backend,
+                      bool offload_params_to_cpu,
+                      const String2TensorStorage& tensor_storage_map = {},
+                      const std::string prefix = "vae",
+                      bool decode_only = true)
+            : VAE(version, backend, offload_params_to_cpu),
+              decode_only(decode_only),
+              ae(decode_only) {
+            // Keep scale_input=true (the sd.cpp default): the VAE::decode
+            // output is mapped via (x + 1) / 2 into [0, 1] before the frame
+            // extraction. LTX-2.3's VAE is trained to produce values in
+            // roughly [-1, 1] per channel, so this is the correct range.
+            // scale_input = false // <-- was here, caused black frames
+            ae.init(params_ctx, tensor_storage_map, prefix);
+        }
+
+        std::string get_desc() override { return "ltxv2.3_vae"; }
+
+        void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {
+            ae.get_param_tensors(tensors, prefix);
+        }
+
+        int get_encoder_output_channels(int input_channels) override {
+            SD_UNUSED(input_channels);
+            return 129;
+        }
+
+        sd::Tensor vae_output_to_latents(const sd::Tensor& vae_output,
+                                         std::shared_ptr<RNG> rng) override {
+            SD_UNUSED(rng);
+            return vae_output;
+        }
+
+        // LTX-2.3 normalises diffusion-space latents to unit variance using the
+        // per-channel stats saved with the VAE:
+        //   diffusion_to_vae (un_normalize) = latents * std + mean
+        //   vae_to_diffusion (normalize)    = (latents - mean) / std
+        // The stats live in the backend under `ae.params["per_channel_statistics.*"]`;
+        // we materialise them to CPU lazily on the first call.
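+        // Illustration (made-up numbers): with mean[c] = 0.10 and std[c] = 2.0,
+        // a diffusion-space value of 1.5 maps to 1.5 * 2.0 + 0.10 = 3.1 in VAE
+        // space, and back again via (3.1 - 0.10) / 2.0 = 1.5.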
+ std::vector mean_of_means; + std::vector std_of_means; + bool stats_loaded = false; + void load_stats_cpu() { + if (stats_loaded) return; + std::map tensors; + ae.get_param_tensors(tensors); + auto mm = tensors.find("per_channel_statistics.mean-of-means"); + auto sm = tensors.find("per_channel_statistics.std-of-means"); + if (mm == tensors.end() || sm == tensors.end() || !mm->second || !sm->second) return; + ggml_tensor* m = mm->second; + ggml_tensor* s = sm->second; + int64_t C = m->ne[0]; + mean_of_means.resize(C); + std_of_means.resize(C); + ggml_backend_tensor_get(m, mean_of_means.data(), 0, C * sizeof(float)); + ggml_backend_tensor_get(s, std_of_means.data(), 0, C * sizeof(float)); + stats_loaded = true; + LOG_INFO("[ltxv.stats] per-channel stats loaded: C=%lld mean[0..3]=%g %g %g std[0..3]=%g %g %g", + (long long)C, + mean_of_means[0], mean_of_means[1], mean_of_means[2], + std_of_means[0], std_of_means[1], std_of_means[2]); + } + // latents shape: [W, H, F, C, N] or [W, H, F, C] (missing batch axis). + // The data layout is row-major with shape[0] the fastest-varying dim, + // so index(w,h,f,c,n) = n*W*H*F*C + c*W*H*F + f*W*H + h*W + w. + sd::Tensor diffusion_to_vae_latents(const sd::Tensor& latents) override { + load_stats_cpu(); + if (!stats_loaded) return latents; + sd::Tensor out(latents.shape()); + const auto& sh = latents.shape(); + int64_t W = sh.size() > 0 ? sh[0] : 1; + int64_t H = sh.size() > 1 ? sh[1] : 1; + int64_t F = sh.size() > 2 ? sh[2] : 1; + int64_t C = sh.size() > 3 ? sh[3] : 1; + int64_t N = sh.size() > 4 ? sh[4] : 1; + if ((size_t)C != mean_of_means.size()) return latents; + const float* src = latents.data(); + float* dst = out.data(); + int64_t plane = W * H * F; + for (int64_t n = 0; n < N; ++n) { + for (int64_t c = 0; c < C; ++c) { + float mu = mean_of_means[c]; + float sg = std_of_means[c]; + int64_t off = (n * C + c) * plane; + for (int64_t i = 0; i < plane; ++i) { + dst[off + i] = src[off + i] * sg + mu; + } + } + } + return out; + } + sd::Tensor vae_to_diffusion_latents(const sd::Tensor& latents) override { + load_stats_cpu(); + if (!stats_loaded) return latents; + sd::Tensor out(latents.shape()); + const auto& sh = latents.shape(); + int64_t W = sh.size() > 0 ? sh[0] : 1; + int64_t H = sh.size() > 1 ? sh[1] : 1; + int64_t F = sh.size() > 2 ? sh[2] : 1; + int64_t C = sh.size() > 3 ? sh[3] : 1; + int64_t N = sh.size() > 4 ? 
sh[4] : 1; + if ((size_t)C != mean_of_means.size()) return latents; + const float* src = latents.data(); + float* dst = out.data(); + int64_t plane = W * H * F; + for (int64_t n = 0; n < N; ++n) { + for (int64_t c = 0; c < C; ++c) { + float mu = mean_of_means[c]; + float sg = std_of_means[c]; + int64_t off = (n * C + c) * plane; + for (int64_t i = 0; i < plane; ++i) { + dst[off + i] = (src[off + i] - mu) / sg; + } + } + } + return out; + } + + protected: + struct ggml_cgraph* build_graph_decode(const sd::Tensor& z) { + auto gf = ggml_new_graph_custom(compute_ctx, LTXV_GRAPH_SIZE, false); + auto z_t = make_input(z); + auto rctx = get_context(); + auto h = ae.decode(&rctx, z_t); + ggml_build_forward_expand(gf, h); + return gf; + } + + struct ggml_cgraph* build_graph_encode(const sd::Tensor& x) { + auto gf = ggml_new_graph_custom(compute_ctx, LTXV_GRAPH_SIZE, false); + auto x_t = make_input(x); + auto rctx = get_context(); + auto h = ae.encode(&rctx, x_t); + ggml_build_forward_expand(gf, h); + return gf; + } + + sd::Tensor _compute(const int n_threads, + const sd::Tensor& z, + bool decode_graph) override { + LTXVRunner::log_tensor_stats(decode_graph ? "vae_in_decode_z" : "vae_in_encode_x", z); + auto get_graph = [&]() -> struct ggml_cgraph* { + return decode_graph ? build_graph_decode(z) : build_graph_encode(z); + }; + auto result = GGMLRunner::compute(get_graph, n_threads, false); + if (!result.has_value()) return {}; + sd::Tensor out = std::move(*result); + LTXVRunner::log_tensor_stats(decode_graph ? "vae_out_decode" : "vae_out_encode", out); + if (decode_graph && out.dim() == 4) { + out.unsqueeze_(out.dim()); + } + return out; + } + }; + +} // namespace LTXV -#endif \ No newline at end of file +#endif // __LTXV_HPP__ diff --git a/src/model.cpp b/src/model.cpp index 3479a0bea..fa6fe7885 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -450,6 +450,21 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) { return VERSION_SD3; } + // LTX-Video 2.3: unique audio-visual weights distinguish it from every + // other DiT family. 
Any of these top-level tensors is present only in + // the joint audio-visual LTX-2.3 architecture: + // * audio_scale_shift_table (2,2048) — per-modality final modulation + // * audio_patchify_proj — audio latent input projection + // * audio_adaln_single — audio timestep embedder + // * av_ca_video_scale_shift_adaln_single — a2v cross-attn modulation + // * video_embeddings_connector — the LTX-2.3 prompt re-embedder + if (tensor_storage.name == "model.diffusion_model.audio_scale_shift_table" || + tensor_storage.name.find("model.diffusion_model.audio_patchify_proj.") != std::string::npos || + tensor_storage.name.find("model.diffusion_model.audio_adaln_single.") != std::string::npos || + tensor_storage.name.find("model.diffusion_model.av_ca_video_scale_shift_adaln_single.") != std::string::npos || + tensor_storage.name.find("model.diffusion_model.video_embeddings_connector.") != std::string::npos) { + return VERSION_LTXV2; + } if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) { return VERSION_QWEN_IMAGE; } diff --git a/src/model.h b/src/model.h index 65bc6c367..a319c4ba7 100644 --- a/src/model.h +++ b/src/model.h @@ -45,6 +45,7 @@ enum SDVersion { VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, VERSION_ERNIE_IMAGE, + VERSION_LTXV2, VERSION_COUNT, }; @@ -139,6 +140,13 @@ static inline bool sd_version_is_ernie_image(SDVersion version) { return false; } +static inline bool sd_version_is_ltxv2(SDVersion version) { + if (version == VERSION_LTXV2) { + return true; + } + return false; +} + static inline bool sd_version_uses_flux2_vae(SDVersion version) { if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version)) { return true; @@ -165,7 +173,8 @@ static inline bool sd_version_is_dit(SDVersion version) { sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_z_image(version) || - sd_version_is_ernie_image(version)) { + sd_version_is_ernie_image(version) || + sd_version_is_ltxv2(version)) { return true; } return false; diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index c65411489..94ed2fcb2 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -1,3 +1,5 @@ +#include + #include "ggml_extend.hpp" #include "model.h" @@ -564,6 +566,85 @@ class StableDiffusionGGML { offload_params_to_cpu, tensor_storage_map, "model.diffusion_model"); + } else if (sd_version_is_ltxv2(version)) { + // LTX-2.3 uses Gemma-3-12B as its text encoder. The encoder + // weights live OUTSIDE the 22B safetensors — the caller + // points `text_encoder_path` at a directory containing + // `tokenizer.model` plus Gemma safetensors shards. The + // 4096-dim aggregate Linear (`text_embedding_projection. + // video_aggregate_embed`) IS in the 22B checkpoint and we + // wire it into LTXV2Conditioner. + auto ltxv_cond = std::make_shared(4096, 128); + diffusion_model = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); + // Build the projection runner from the LTX 22B safetensors. + // Its weights are loaded together with the rest of the + // LTX tensors further down (we register them in `tensors`). 
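+                // The 188160-wide input is presumably the stacked Gemma-3-12B
+                // hidden states: (48 layers + the embedding layer) x 3840
+                // features = 49 x 3840 = 188160 per token, projected down to
+                // the 4096-dim caption channel the transformer cross-attends to.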
+ auto proj = std::make_shared( + clip_backend, offload_params_to_cpu, tensor_storage_map, + /*in=*/188160, /*out=*/4096); + if (!proj->alloc_params_buffer()) { + LOG_ERROR("text_embedding_projection params buffer alloc failed"); + return false; + } + proj->get_param_tensors(tensors, + "text_embedding_projection.video_aggregate_embed"); + // If a Gemma directory was provided, load it (heavy). + const char* gemma_dir = SAFE_STR(sd_ctx_params->text_encoder_path); + if (gemma_dir && gemma_dir[0] != '\0') { + std::string tok_path = std::string(gemma_dir) + "/tokenizer.model"; + auto tok = std::make_shared(); + std::string terr; + if (!tok->load_from_spm(tok_path, &terr)) { + LOG_WARN("failed to load Gemma tokenizer at %s: %s", + tok_path.c_str(), terr.c_str()); + } else { + // Enumerate safetensors shards in the directory. + std::vector gemma_files; + if (DIR* d = opendir(gemma_dir)) { + struct dirent* e; + while ((e = readdir(d)) != nullptr) { + std::string name = e->d_name; + if (name.size() > 12 && + name.substr(name.size() - 12) == ".safetensors") { + gemma_files.push_back(std::string(gemma_dir) + "/" + name); + } + } + closedir(d); + std::sort(gemma_files.begin(), gemma_files.end()); + } + ModelLoader gemma_loader; + bool loaded_any = false; + for (const auto& f : gemma_files) { + if (gemma_loader.init_from_file(f, /*prefix=*/"language_model.")) { + loaded_any = true; + } + } + if (loaded_any) { + auto gemma = std::make_shared( + clip_backend, offload_params_to_cpu, + gemma_loader.get_tensor_storage_map(), + /*prefix=*/"model"); + gemma->alloc_params_buffer(); + std::map gt; + gemma->get_param_tensors(gt, "language_model.model"); + if (gemma_loader.load_tensors(gt, /*ignore=*/{}, n_threads)) { + ltxv_cond->attach_gemma(gemma, tok, proj); + LOG_INFO("LTX-2.3 Gemma-3 text encoder loaded"); + } else { + LOG_WARN("failed to load Gemma tensors"); + } + } else { + LOG_WARN("failed to enumerate Gemma shards at %s", gemma_dir); + } + } + } else { + LOG_INFO("LTX-2.3: no text_encoder_path set — running unconditional"); + } + cond_stage_model = ltxv_cond; } else { // SD1.x SD2.x SDXL std::map embbeding_map; for (uint32_t i = 0; i < sd_ctx_params->embedding_count; i++) { @@ -638,6 +719,14 @@ class StableDiffusionGGML { }; auto create_vae = [&]() -> std::shared_ptr { + if (sd_version_is_ltxv2(version)) { + return std::make_shared(version, + vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "first_stage_model", + vae_decode_only); + } if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { @@ -836,6 +925,14 @@ class StableDiffusionGGML { ignore_tensors.insert("text_encoders.llm.vision_tower."); ignore_tensors.insert("text_encoders.llm.multi_modal_projector."); } + if (sd_version_is_ltxv2(version)) { + // LTX-2.3 single-file checkpoints also contain audio VAE and a + // vocoder that the video-only pipeline does not consume. + // `text_embedding_projection.*` IS consumed when the conditioner + // is wired up with a Gemma-3 text encoder (see LTXV2Conditioner). 
+ ignore_tensors.insert("audio_vae."); + ignore_tensors.insert("vocoder."); + } bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads, sd_ctx_params->enable_mmap); if (!success) { LOG_ERROR("load tensors from model loader failed"); @@ -940,12 +1037,17 @@ class StableDiffusionGGML { sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_ernie_image(version) || - sd_version_is_z_image(version)) { + sd_version_is_z_image(version) || + sd_version_is_ltxv2(version)) { pred_type = FLOW_PRED; if (sd_version_is_wan(version)) { default_flow_shift = 5.f; } else if (sd_version_is_ernie_image(version)) { default_flow_shift = 4.f; + } else if (sd_version_is_ltxv2(version)) { + // LTX uses dynamic shift in diffusers (shape-dependent). + // Use a fixed default; tune per hardware-verification run. + default_flow_shift = 3.f; } else { default_flow_shift = 3.f; } @@ -1866,6 +1968,8 @@ class StableDiffusionGGML { latent_channel = 3; } else if (sd_version_uses_flux2_vae(version)) { latent_channel = 128; + } else if (sd_version_is_ltxv2(version)) { + latent_channel = 128; } else { latent_channel = 16; } @@ -1888,6 +1992,9 @@ class StableDiffusionGGML { int T = frames; if (sd_version_is_wan(version)) { T = ((T - 1) / 4) + 1; + } else if (sd_version_is_ltxv2(version)) { + // LTX VAE temporal compression factor = 8 + T = ((T - 1) / 8) + 1; } int C = get_latent_channel(); if (video) { @@ -2619,7 +2726,14 @@ struct GenerationRequest { negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt); width = sd_vid_gen_params->width; height = sd_vid_gen_params->height; - frames = (sd_vid_gen_params->video_frames - 1) / 4 * 4 + 1; + // Pad frame count to what each VAE family can decode. + // Wan temporal compression = 4 → frames must be 4k+1. + // LTX temporal compression = 8 → frames must be 8k+1. + { + SDVersion ver = sd_ctx->sd->version; + int temporal_grid = sd_version_is_ltxv2(ver) ? 8 : 4; + frames = (sd_vid_gen_params->video_frames - 1) / temporal_grid * temporal_grid + 1; + } clip_skip = sd_vid_gen_params->clip_skip; vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor(); @@ -2817,6 +2931,18 @@ struct SamplePlan { high_noise_sample_steps = total_steps - sample_steps; LOG_WARN("total_steps != custom_sigmas_count - 1, set high_noise_sample_steps to %d", high_noise_sample_steps); } + } else if (sd_version_is_ltxv2(sd_ctx->sd->version) && total_steps == 8) { + // LTX-2.3 distilled default schedule — a hand-tuned non-linear + // sigma sequence clustered near 1 with a sharp drop at the end, + // per `DISTILLED_SIGMA_VALUES` in ltx_pipelines.utils.constants. + // Applied only when the user asked for exactly 8 sampling + // steps (the distilled model's target). Otherwise fall through + // to the generic shifted flow schedule. + sigmas = {1.0f, 0.99375f, 0.9875f, 0.98125f, + 0.975f, 0.909375f, 0.725f, 0.421875f, 0.0f}; + total_steps = 8; + sample_steps = 8; + LOG_INFO("Using LTX-2.3 distilled 8-step sigma schedule"); } else { scheduler_t scheduler = resolve_scheduler(sd_ctx, sample_params->scheduler, diff --git a/src/tokenizers/gemma3_tokenizer.cpp b/src/tokenizers/gemma3_tokenizer.cpp new file mode 100644 index 000000000..6b629a4cc --- /dev/null +++ b/src/tokenizers/gemma3_tokenizer.cpp @@ -0,0 +1,325 @@ +// Gemma-3 SentencePiece BPE tokenizer — implementation. 
+//
+// Protobuf wire format for sentencepiece.ModelProto (only the fields we
+// care about):
+//   message ModelProto {
+//     repeated SentencePiece pieces = 1;
+//     // ... unused fields
+//   }
+//   message SentencePiece {
+//     string piece = 1;
+//     float score  = 2;
+//     Type type    = 3;  // enum, wire-type = varint
+//   }
+//
+// We parse exactly this subset — everything else (trainer_spec,
+// normalizer_spec, etc.) is skipped via tag/length walks.
+
+#include "gemma3_tokenizer.h"
+
+#include <cstdint>
+#include <cstring>
+#include <fstream>
+#include <limits>
+#include <vector>
+
+namespace {
+
+// --- protobuf wire format helpers ------------------------------------------
+
+struct Reader {
+    const uint8_t* p;
+    const uint8_t* end;
+
+    bool eof() const { return p >= end; }
+
+    bool read_varint(uint64_t& out) {
+        out = 0;
+        int shift = 0;
+        while (p < end) {
+            uint8_t b = *p++;
+            out |= (uint64_t)(b & 0x7f) << shift;
+            if ((b & 0x80) == 0) return true;
+            shift += 7;
+            if (shift >= 64) return false;
+        }
+        return false;
+    }
+
+    bool read_fixed32(uint32_t& out) {
+        if (end - p < 4) return false;
+        out = (uint32_t)p[0] | ((uint32_t)p[1] << 8) |
+              ((uint32_t)p[2] << 16) | ((uint32_t)p[3] << 24);
+        p += 4;
+        return true;
+    }
+
+    bool read_fixed64(uint64_t& out) {
+        if (end - p < 8) return false;
+        uint64_t v = 0;
+        for (int i = 0; i < 8; ++i) v |= (uint64_t)p[i] << (8 * i);
+        p += 8;
+        out = v;
+        return true;
+    }
+
+    // Skip a field of the given wire type (for unused sections).
+    bool skip_field(int wire_type) {
+        if (wire_type == 0) {  // varint
+            uint64_t tmp;
+            return read_varint(tmp);
+        } else if (wire_type == 1) {  // fixed64
+            uint64_t tmp;
+            return read_fixed64(tmp);
+        } else if (wire_type == 2) {  // length-delimited
+            uint64_t len;
+            if (!read_varint(len)) return false;
+            if ((uint64_t)(end - p) < len) return false;
+            p += len;
+            return true;
+        } else if (wire_type == 5) {  // fixed32
+            uint32_t tmp;
+            return read_fixed32(tmp);
+        }
+        return false;
+    }
+};
+
+// Parse one SentencePiece message from the length-delimited sub-view [data, data + len).
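+// Worked example (illustrative bytes): 0x0A 0x03 'c' 'a' 't' 0x15 b0 b1 b2 b3
+// parses as tag 0x0A = (field 1, wire 2), a 3-byte piece "cat", followed by
+// tag 0x15 = (field 2, wire 5), a little-endian fixed32 float score.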
+bool parse_piece(const uint8_t* data, size_t len, Gemma3Tokenizer::Piece& out) {
+    Reader r{data, data + len};
+    out       = {};
+    out.type  = Gemma3Tokenizer::NORMAL;
+    out.score = 0.0f;
+    while (!r.eof()) {
+        uint64_t tag;
+        if (!r.read_varint(tag)) return false;
+        int field = (int)(tag >> 3);
+        int wire  = (int)(tag & 0x07);
+        if (field == 1 && wire == 2) {  // piece (string)
+            uint64_t slen;
+            if (!r.read_varint(slen)) return false;
+            if ((uint64_t)(r.end - r.p) < slen) return false;
+            out.text.assign((const char*)r.p, slen);
+            r.p += slen;
+        } else if (field == 2 && wire == 5) {  // score (float)
+            uint32_t bits;
+            if (!r.read_fixed32(bits)) return false;
+            float f;
+            std::memcpy(&f, &bits, 4);
+            out.score = f;
+        } else if (field == 3 && wire == 0) {  // type (enum)
+            uint64_t v;
+            if (!r.read_varint(v)) return false;
+            out.type = (uint8_t)v;
+        } else {
+            if (!r.skip_field(wire)) return false;
+        }
+    }
+    return true;
+}
+
+}  // namespace
+
+bool Gemma3Tokenizer::load_from_spm(const std::string& path, std::string* error) {
+    std::ifstream f(path, std::ios::binary);
+    if (!f) {
+        if (error) *error = "cannot open " + path;
+        return false;
+    }
+    std::vector<uint8_t> buf((std::istreambuf_iterator<char>(f)),
+                             std::istreambuf_iterator<char>());
+
+    Reader r{buf.data(), buf.data() + buf.size()};
+    pieces_.clear();
+    piece_to_id_.clear();
+
+    while (!r.eof()) {
+        uint64_t tag;
+        if (!r.read_varint(tag)) {
+            if (error) *error = "truncated protobuf";
+            return false;
+        }
+        int field = (int)(tag >> 3);
+        int wire  = (int)(tag & 0x07);
+
+        if (field == 1 && wire == 2) {  // pieces (repeated SentencePiece)
+            uint64_t slen;
+            if (!r.read_varint(slen)) return false;
+            if ((uint64_t)(r.end - r.p) < slen) return false;
+            Piece p;
+            if (!parse_piece(r.p, slen, p)) {
+                if (error) *error = "malformed SentencePiece";
+                return false;
+            }
+            piece_to_id_[p.text] = (int32_t)pieces_.size();
+            pieces_.push_back(std::move(p));
+            r.p += slen;
+        } else {
+            if (!r.skip_field(wire)) {
+                if (error) *error = "cannot skip unknown field";
+                return false;
+            }
+        }
+    }
+
+    // Locate special tokens by name. Gemma's convention matches llama.cpp.
+    auto find = [&](const std::string& s, int32_t fallback) -> int32_t {
+        auto it = piece_to_id_.find(s);
+        return it == piece_to_id_.end() ? fallback : it->second;
+    };
+    pad_id_ = find("<pad>", 0);
+    eos_id_ = find("<eos>", 1);
+    bos_id_ = find("<bos>", 2);
+    unk_id_ = find("<unk>", 3);
+
+    return true;
+}
+
+// Gemma's meta-space prefix byte sequence: U+2581 (LOWER ONE EIGHTH BLOCK)
+// encoded as UTF-8: 0xE2 0x96 0x81 (three bytes).
+static const std::string kMetaSpace = "\xE2\x96\x81";
+
+// Byte-level fallback: SentencePiece encodes unknown bytes as
+// "<0xHH>" pieces (tokens 6..261 cover 0x00..0xFF). Gemma uses the same.
+static std::string byte_piece(uint8_t b) {
+    static const char hex[] = "0123456789ABCDEF";
+    char buf[8];
+    buf[0] = '<'; buf[1] = '0'; buf[2] = 'x';
+    buf[3] = hex[(b >> 4) & 0xf];
+    buf[4] = hex[b & 0xf];
+    buf[5] = '>'; buf[6] = 0;
+    return std::string(buf, 6);
+}
+
+// Classic SentencePiece BPE: split the input into unicode chars prefixed with
+// meta-space at word boundaries, then repeatedly merge adjacent pairs
+// using piece scores (higher = earlier) until no merge is possible.
+//
+// `word` here is the raw UTF-8 string including its leading meta-space.
+void Gemma3Tokenizer::bpe_encode_word(const std::string& word,
+                                      std::vector<int32_t>& out) const {
+    if (word.empty()) return;
+
+    // 1) Break into individual unicode code points (1-4 byte UTF-8 runs),
+    //    each represented by its piece string. If a codepoint has no
+    //    direct piece, fall back to its bytes as <0xHH> pieces.
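+    //    e.g. (illustrative) the word "▁café": "▁", "c", "a", "f" each resolve
+    //    to a single-codepoint piece, and if "é" (bytes 0xC3 0xA9) had no piece
+    //    of its own it would fall back to the <0xC3>, <0xA9> byte pieces.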
+ struct Sym { + std::string text; + int32_t id; + }; + std::vector syms; + size_t i = 0; + while (i < word.size()) { + // Figure out UTF-8 run length at byte i. + uint8_t c = (uint8_t)word[i]; + size_t len = 1; + if ((c & 0x80) == 0) len = 1; + else if ((c & 0xE0) == 0xC0) len = 2; + else if ((c & 0xF0) == 0xE0) len = 3; + else if ((c & 0xF8) == 0xF0) len = 4; + if (i + len > word.size()) len = 1; + + std::string cp = word.substr(i, len); + auto it = piece_to_id_.find(cp); + if (it != piece_to_id_.end()) { + syms.push_back({cp, it->second}); + } else { + // Fallback to byte-level pieces. + for (size_t b = 0; b < len; ++b) { + std::string bp = byte_piece((uint8_t)word[i + b]); + auto bit = piece_to_id_.find(bp); + int32_t id = bit == piece_to_id_.end() ? unk_id_ : bit->second; + syms.push_back({bp, id}); + } + } + i += len; + } + + // 2) Merge adjacent pairs. Use a priority-based loop: at each step, + // scan for the best merge (highest score), apply it, repeat. + // This is O(N^2) in word length but prompts are short (<300 tokens), + // so it's fine. + while (syms.size() > 1) { + float best_score = -std::numeric_limits::infinity(); + int best_idx = -1; + int32_t best_id = -1; + std::string best_text; + for (size_t k = 0; k + 1 < syms.size(); ++k) { + std::string merged = syms[k].text + syms[k + 1].text; + auto it = piece_to_id_.find(merged); + if (it == piece_to_id_.end()) continue; + float score = pieces_[it->second].score; + if (score > best_score) { + best_score = score; + best_idx = (int)k; + best_id = it->second; + best_text = std::move(merged); + } + } + if (best_idx < 0) break; + syms[best_idx] = {std::move(best_text), best_id}; + syms.erase(syms.begin() + best_idx + 1); + } + + for (auto& s : syms) out.push_back(s.id); +} + +std::vector Gemma3Tokenizer::encode(const std::string& text, + bool add_bos, + bool add_eos) const { + std::vector ids; + if (add_bos) ids.push_back(bos_id_); + + // Pre-tokenisation: SentencePiece replaces spaces with the meta-space + // character. Gemma-3 disables the "add_dummy_prefix" behaviour, so the + // FIRST word is encoded *without* a leading meta-space (the first chunk + // has no prefix), but subsequent words get one. + std::string normalised; + normalised.reserve(text.size() + kMetaSpace.size() * 4); + for (size_t i = 0; i < text.size(); ++i) { + char c = text[i]; + if (c == ' ') normalised += kMetaSpace; + else normalised += c; + } + + // Split into chunks. The first chunk (before any meta-space) is encoded + // as-is; subsequent chunks (starting at a meta-space boundary) include + // their leading meta-space as part of the word. + if (!normalised.empty()) { + size_t first_ms = normalised.find(kMetaSpace); + size_t end0 = first_ms == std::string::npos ? 
normalised.size() : first_ms; + if (end0 > 0) { + bpe_encode_word(normalised.substr(0, end0), ids); + } + size_t pos = first_ms; + while (pos != std::string::npos && pos < normalised.size()) { + size_t next = normalised.find(kMetaSpace, pos + kMetaSpace.size()); + if (next == std::string::npos) next = normalised.size(); + std::string word = normalised.substr(pos, next - pos); + bpe_encode_word(word, ids); + pos = next; + } + } + + if (add_eos) ids.push_back(eos_id_); + return ids; +} + +std::string Gemma3Tokenizer::decode(const std::vector& ids) const { + std::string out; + for (int32_t id : ids) { + if (id < 0 || id >= (int32_t)pieces_.size()) continue; + const auto& p = pieces_[id]; + if (p.type == CONTROL) continue; // skip BOS/EOS/pad + std::string piece = p.text; + // Convert back: meta-space → regular space. + size_t pos = 0; + while ((pos = piece.find(kMetaSpace, pos)) != std::string::npos) { + piece.replace(pos, kMetaSpace.size(), " "); + pos += 1; + } + out += piece; + } + return out; +} diff --git a/src/tokenizers/gemma3_tokenizer.h b/src/tokenizers/gemma3_tokenizer.h new file mode 100644 index 000000000..ca2c027d5 --- /dev/null +++ b/src/tokenizers/gemma3_tokenizer.h @@ -0,0 +1,79 @@ +// Gemma-3 SentencePiece BPE tokenizer. +// +// Reads a raw `tokenizer.model` protobuf file (same format HuggingFace +// transformers and llama.cpp consume) and performs byte-level BPE encoding +// using the piece scores as merge priorities. +// +// SentencePiece vocab layout (for Gemma-3-12B): +// 262208 total pieces. First 4 are special control/unknown tokens +// ( id=0, id=1, id=2, id=3, id=4). +// Most pieces are normal sub-word tokens; a small number are CONTROL +// or USER_DEFINED (BOS/EOS/pad/mask/turn markers). +// +// For LTX-2.3, we tokenise the raw prompt with BOS prepended (Gemma +// convention) and EOS appended; we do NOT apply chat templates — LTX +// uses the Gemma base text encoder on raw text. + +#ifndef __SD_TOKENIZERS_GEMMA3_TOKENIZER_H__ +#define __SD_TOKENIZERS_GEMMA3_TOKENIZER_H__ + +#include +#include +#include +#include + +class Gemma3Tokenizer { +public: + enum TokenType : uint8_t { + NORMAL = 1, + UNKNOWN = 2, + CONTROL = 3, + USER_DEFINED = 4, + BYTE = 5, + UNUSED = 6, + }; + + struct Piece { + std::string text; + float score = 0.0f; + uint8_t type = NORMAL; + }; + + // Load vocab + scores from a SentencePiece protobuf (*.model) file. + // Returns true on success. On failure, `error` holds a message. + bool load_from_spm(const std::string& path, std::string* error = nullptr); + + // Encode `text` into token ids. If `add_bos` is true, prepends the BOS + // id; if `add_eos`, appends EOS. + // + // Algorithm: byte-level pre-tokenization with the Gemma meta-space + // prefix ("▁"), then BPE merges driven by piece scores. Highest-score + // pair wins at each step. + std::vector encode(const std::string& text, + bool add_bos = true, + bool add_eos = false) const; + + // Decoding is not required for LTX use, but trivial enough to expose. 
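+    // Sketch of the intended round trip (token ids are vocab-specific, shown
+    // only symbolically): decode(encode("a cat", /*add_bos=*/true)) returns
+    // "a cat", since BOS/EOS are CONTROL pieces and are skipped on decode.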
+ std::string decode(const std::vector& ids) const; + + int32_t bos_id() const { return bos_id_; } + int32_t eos_id() const { return eos_id_; } + int32_t pad_id() const { return pad_id_; } + int32_t unk_id() const { return unk_id_; } + int32_t vocab_size() const { return (int32_t)pieces_.size(); } + + const std::vector& pieces() const { return pieces_; } + +private: + std::vector pieces_; + std::unordered_map piece_to_id_; + int32_t bos_id_ = 2; + int32_t eos_id_ = 1; + int32_t pad_id_ = 0; + int32_t unk_id_ = 3; + + // Encodes a single pre-tokenised word into the BPE sequence. + void bpe_encode_word(const std::string& word, std::vector& out) const; +}; + +#endif // __SD_TOKENIZERS_GEMMA3_TOKENIZER_H__ diff --git a/src/vae.hpp b/src/vae.hpp index dc69535e8..634de0cfe 100644 --- a/src/vae.hpp +++ b/src/vae.hpp @@ -73,6 +73,9 @@ struct VAE : public GGMLRunner { scale_factor = 16; } else if (version == VERSION_CHROMA_RADIANCE) { scale_factor = 1; + } else if (sd_version_is_ltxv2(version)) { + // LTX VAE: patch_size=4 spatial, plus 3 down-blocks (x2 each) → 4 * 2^3 = 32. + scale_factor = 32; } return scale_factor; } diff --git a/tests/gemma3_tokenizer_test.cpp b/tests/gemma3_tokenizer_test.cpp new file mode 100644 index 000000000..292261963 --- /dev/null +++ b/tests/gemma3_tokenizer_test.cpp @@ -0,0 +1,49 @@ +// Manual test: tokenise a prompt with our Gemma-3 BPE tokenizer. +// +// Build: +// c++ -std=c++17 -O2 -Isrc \ +// tests/gemma3_tokenizer_test.cpp \ +// src/tokenizers/gemma3_tokenizer.cpp \ +// -o /tmp/gemma3_tok_test +// +// Run: +// /tmp/gemma3_tok_test /path/to/tokenizer.model "a cat walking across a grassy field" +// +// Compare output to the reference printed by +// python - <<'PY' +// from transformers import AutoTokenizer +// tok = AutoTokenizer.from_pretrained("google/gemma-3-12b-it") +// print(tok.encode("a cat walking across a grassy field")) +// PY + +#include +#include +#include + +#include "tokenizers/gemma3_tokenizer.h" + +int main(int argc, char** argv) { + if (argc < 3) { + std::fprintf(stderr, "usage: %s [add_eos=0|1]\n", argv[0]); + return 1; + } + Gemma3Tokenizer tok; + std::string err; + if (!tok.load_from_spm(argv[1], &err)) { + std::fprintf(stderr, "load failed: %s\n", err.c_str()); + return 2; + } + std::fprintf(stderr, "loaded %d pieces (bos=%d eos=%d pad=%d unk=%d)\n", + (int)tok.vocab_size(), (int)tok.bos_id(), + (int)tok.eos_id(), (int)tok.pad_id(), (int)tok.unk_id()); + + bool add_eos = argc > 3 && std::atoi(argv[3]) != 0; + auto ids = tok.encode(argv[2], /*add_bos=*/true, /*add_eos=*/add_eos); + for (size_t i = 0; i < ids.size(); ++i) { + std::printf("%s%d", i ? "," : "", ids[i]); + } + std::printf("\n"); + std::fprintf(stderr, "count=%zu decoded=\"%s\"\n", ids.size(), + tok.decode(ids).c_str()); + return 0; +}