
fix(ltx2): cap encoder channels to match pretrained VAE weights #27

Open
silas-dsc wants to merge 1 commit into Blaizzy:main from silas-dsc:fix/ltx2-encoder-channel-cap

Conversation

@silas-dsc

Summary

  • The LTX-2.3-dev pretrained VAE encoder caps feature channels at 1024, but `_make_encoder_block` in `mlx_video/models/ltx_2/video_vae/video_vae.py` blindly applied `out_channels = in_channels * multiplier`. With the default config, the second `compress_all_res` block tried to expand 1024 → 2048 channels, while the actual conv weight `(128, 3, 3, 3, 1024)` only produces 128 × 8 = 1024 channels after space-to-depth.

  • This caused a broadcast error in `SpaceToDepthDownsample.__call__` (`return x_conv + x_in`) on every I2V / two-stage HQ run that exercises the VAE encoder:

    ValueError: [broadcast_shapes] Shapes (1,1024,1,8,12) and (1,2048,1,8,12) cannot be broadcast.

  • Add an optional `max_channels` (default 1024) to the three `compress_*_res` block builders so `out_channels = min(in_channels * multiplier, max_channels)`. This matches the pretrained weights (e.g. `conv_out.conv.weight` has in=1024, `down_blocks.7.conv.conv.weight` has out=128).
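The capped builder logic can be sketched in isolation. Everything below other than the `min()` cap and the 128 × 8 arithmetic quoted above is illustrative — the real `compress_*_res` helpers in `video_vae.py` take more parameters, and the exact channel progression is an assumption, not copied from the config:

```python
MAX_CHANNELS = 1024  # widest channel count present in the pretrained VAE weights

def block_out_channels(in_channels: int, multiplier: int = 2,
                       max_channels: int = MAX_CHANNELS) -> int:
    """Capped version of `out_channels = in_channels * multiplier`."""
    return min(in_channels * multiplier, max_channels)

# Walk a hypothetical chain of doubling encoder blocks starting at 128
# channels. Uncapped, the fifth block would request 4096 channels; the
# first doubling past 1024 is exactly the 1024 -> 2048 jump that crashed.
channels = 128
for _ in range(5):
    channels = block_out_channels(channels)
print(channels)  # stays pinned at 1024

# Why 1024 is the right cap: the pretrained conv weight has out=128, and a
# 2x2x2 space-to-depth multiplies channels by 8, giving 128 * 8 = 1024 on
# the conv branch -- so the residual branch must also be 1024.
assert 128 * (2 * 2 * 2) == 1024
```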

Reproduction

python -m mlx_video.models.ltx_2.generate \
  --pipeline dev-two-stage-hq \
  --model-repo prince-canuma/LTX-2.3-dev \
  --width 768 --height 512 --num-frames 81 \
  --image <some.jpg> --image-strength 0.5 \
  --prompt "..."

Before the fix, this crashes inside `vae_encoder(stage1_image_tensor)`. After the fix, the encoder produces a `(1, 128, 1, 8, 12)` latent and the full pipeline runs end-to-end (verified locally — 81-frame 768x512 mp4 generated in ~18m, peak 44.88GB).
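The failure mode can also be reproduced in isolation with plain NumPy, independent of MLX and the model weights — the shapes below are taken from the traceback quoted above, and the variable names mirror `SpaceToDepthDownsample.__call__`:

```python
import numpy as np

# Residual branch at the channel count the pretrained weights actually
# produce, vs. the over-expanded conv branch the uncapped builder requested.
x_in = np.zeros((1, 1024, 1, 8, 12))
x_conv_bad = np.zeros((1, 2048, 1, 8, 12))

try:
    _ = x_conv_bad + x_in  # channel dims 2048 vs 1024: neither is 1, no broadcast
except ValueError as e:
    print("broadcast fails:", e)

# With the cap in place both branches agree and the residual add works.
x_conv_ok = np.zeros((1, 1024, 1, 8, 12))
assert (x_conv_ok + x_in).shape == (1, 1024, 1, 8, 12)
```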

Test plan

  • Encode a `(1, 3, 1, 256, 384)` image through `VideoEncoder.from_pretrained(...)` — no shape errors, output shape `(1, 128, 1, 8, 12)`.
  • End-to-end `dev-two-stage-hq` pipeline with I2V + A2V — completes and writes a valid mp4.
  • Spot-check that other LTX-2 model variants (where channels naturally stay under 1024) still load — the cap is a no-op for them.
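The expected output shape in the first check can be derived from the numbers in this PR alone: 256 → 8 and 384 → 12 imply a 32× spatial compression for this encoder (an inference from the quoted shapes, not a documented constant):

```python
SPATIAL_FACTOR = 32  # assumed from the 256 -> 8 and 384 -> 12 mapping above

def latent_hw(height: int, width: int, factor: int = SPATIAL_FACTOR):
    """Expected latent grid for a given input resolution (sketch only)."""
    assert height % factor == 0 and width % factor == 0, "dims must divide evenly"
    return height // factor, width // factor

print(latent_hw(256, 384))  # -> (8, 12), matching the test-plan shape above
```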

🤖 Generated with Claude Code

