mtmd: add Gemma 4 audio conformer encoder support #21421

Merged
ngxson merged 19 commits into ggml-org:master from stephencox-ict:gemma4-audio-pr
Apr 12, 2026

Conversation

@stephencox-ict (Contributor) commented Apr 4, 2026

Important

It is recommended to use BF16 mmproj. Other quantizations are known to have degraded performance; ref comment: #21421 (comment)

Overview

Add audio processing support for Gemma 4 models via a USM-style Conformer encoder.

Architecture:

  • 12-layer Conformer: FFN -> Self-Attention -> Causal Conv1D -> FFN -> Norm
  • Subsampling Conv Projection: 2x Conv2D(stride=2) with LayerNorm
  • Chunked local attention with sinusoidal RPE (chunk_size=12, context_size=24)
  • Logit softcapping at 50.0 (sketched after this list), ClippableLinear with per-tensor clamping
  • Output projection -> RMSNorm -> multimodal embedder
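
The softcapping mentioned above is, in the usual formulation, a scaled tanh. The sketch below is illustrative, not this PR's code; only the cap value (50.0) comes from the description:

```cpp
#include <cmath>

// Soft-cap a raw attention logit so it saturates smoothly at +/- cap.
// Illustrative sketch of the standard tanh-based softcap; the cap value
// 50.0f is from the architecture summary above, everything else is generic.
static float softcap_logit(float logit, float cap = 50.0f) {
    return cap * std::tanh(logit / cap);
}
```

For |logit| much smaller than the cap this is nearly the identity; large logits are squashed toward +/-50, keeping attention scores bounded before the mask is applied.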

Chunked local attention (matching the PyTorch reference; a minimal sketch of the view trick follows this list):

  • Q split into non-overlapping blocks of 12
  • K/V extracted as overlapping context windows of 24 via ggml_view_4d with stride 12
  • Per-block causal mask matching PyTorch's dist < left_window_size condition
  • Blocked relative position shift (Transformer-XL appendix B)
  • RPE: 13 sinusoidal position embeddings [12, 11, ..., 0]
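
A minimal, self-contained ggml sketch of the overlapping-window view trick on a 1-D toy sequence. Names, sizes, and the 2-D view are illustrative; the encoder itself applies ggml_view_4d to the full [d, T, heads, batch] tensors:

```cpp
#include "ggml.h"

// Illustrative only: extract overlapping K/V context windows of 24 frames
// with stride 12 from a contiguous 1-D sequence. (Graph build/compute is
// omitted; this only shows the view construction.)
int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t T = 48, ctx_len = 24, stride = 12;
    struct ggml_tensor * seq = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, T);

    // Overlapping windows: row i starts at element i*stride. The row byte
    // stride nb1 is deliberately smaller than a full row, so consecutive
    // rows share ctx_len - stride = 12 elements.
    const int64_t n_win = (T - ctx_len) / stride + 1; // = 3 here
    struct ggml_tensor * windows = ggml_view_2d(
        ctx, seq, ctx_len, n_win, stride * ggml_element_size(seq), 0);

    // Most ops want contiguous input, so materialize the view.
    struct ggml_tensor * kv_ctx = ggml_cont(ctx, windows);
    (void) kv_ctx;

    ggml_free(ctx);
    return 0;
}
```

Because nb1 (12 elements) is smaller than a full row (24 elements), consecutive rows overlap by 12 frames, which is exactly the chunk/context relationship in the list above.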

Mel preprocessing (dedicated mtmd_audio_preprocessor_gemma4a):

  • HTK mel scale (conversion sketched after this list), 128 bins, magnitude STFT, mel_floor=1e-3
  • Standard periodic Hann window (320 samples), zero-padded to FFT size
  • 30-second chunking (splits long audio into 30s segments)
  • Mel cosine similarity vs PyTorch: 1.0
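
The HTK mel conversion referenced above is the standard 2595*log10(1 + f/700) mapping; the constants are not specific to this PR, and the bin-edge layout below is shown only for illustration:

```cpp
#include <cmath>
#include <cstdio>

// Standard HTK mel-scale conversions.
static float hz_to_mel(float hz)  { return 2595.0f * std::log10(1.0f + hz / 700.0f); }
static float mel_to_hz(float mel) { return 700.0f * (std::pow(10.0f, mel / 2595.0f) - 1.0f); }

int main() {
    // 128 mel bins spaced uniformly in mel between 0 Hz and Nyquist
    // (8 kHz for 16 kHz input). The actual filterbank construction
    // lives in the preprocessor; this only demonstrates the mapping.
    const int n_mels = 128;
    const float mel_lo = hz_to_mel(0.0f), mel_hi = hz_to_mel(8000.0f);
    for (int i = 0; i < 3; i++) {
        const float hz = mel_to_hz(mel_lo + (mel_hi - mel_lo) * i / (n_mels + 1));
        printf("edge %d: %.1f Hz\n", i, hz);
    }
    return 0;
}
```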

Fixes (beyond the initial encoder):

  1. Contiguous sigmoid input (gemma4a.cpp): Wrap GLU gate view in ggml_cont() before ggml_sigmoid(). The non-contiguous view caused CUDA/Vulkan to fall back to CPU for sigmoid, creating 25 graph splits and numerical divergence on longer audio.

  2. Conv norm swap at load site (clip.cpp): The upstream tensor_mapping.py maps the gemma4 audio lconv1d norms with swapped GGUF names (conv_norm ↔ norm_conv). The loader now reads the tensors in reversed order at the load site to correct this, rather than swapping post-load. Verified by element-wise comparison against the Python transformers safetensors weights.

  3. Replace ggml_roll with ggml_view + ggml_concat (gemma4a.cpp): The ggml_roll op has no Metal kernel, causing 73 graph splits and CPU fallbacks on Apple Silicon. Replaced with equivalent view+concat sequences that are supported on all backends. Audio encoder graph splits reduced from 73 to 1 on Metal. (Fixes 1 and 3 are sketched below.)
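
A hedged sketch of the shape of fixes 1 and 3. Tensor names like gate_view are stand-ins, and this is an illustrative equivalent, not the PR's actual code, which lives in gemma4a.cpp:

```cpp
#include "ggml.h"

// Fix 1 (shape of it): force the GLU gate view contiguous before sigmoid,
// so backends with contiguous-only sigmoid kernels do not fall back to CPU.
static struct ggml_tensor * glu_gate(struct ggml_context * ctx,
                                     struct ggml_tensor * gate_view) {
    return ggml_sigmoid(ctx, ggml_cont(ctx, gate_view));
}

// Fix 3 (shape of it): ggml_roll along dim 0 rewritten as two views plus a
// concat, which every backend supports. Assumes contiguous x, 0 < s < n,
// and the torch.roll convention (positive s moves elements toward higher
// indices): out = [x[n-s..n), x[0..n-s)].
static struct ggml_tensor * roll_dim0(struct ggml_context * ctx,
                                      struct ggml_tensor * x, int64_t s) {
    const int64_t n = x->ne[0];
    struct ggml_tensor * tail = ggml_view_1d(ctx, x, s,     (n - s) * ggml_element_size(x));
    struct ggml_tensor * head = ggml_view_1d(ctx, x, n - s, 0);
    return ggml_concat(ctx, ggml_cont(ctx, tail), ggml_cont(ctx, head), 0);
}
```

The ggml_cont calls materialize the views before the concat, so the whole sequence stays on one backend.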

Usage:

llama-mtmd-cli \
    -m gemma-4-E2B-it-Q6_K.gguf \
    --mmproj mmproj-BF16.gguf \
    --audio sample.wav \
    -p "Transcribe this audio exactly." \
    --temp 1.0 --top-k 64 --top-p 0.95 \
    -ngl 99 --jinja

Audio transcription results (E2B, best across CPU/Vulkan/CUDA):

Short audio (5.9s LibriSpeech):

Ground truth: "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL"
Output:       "Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."

Long audio (17.4s, moon landing narration):

Output: "This is the moment when humanity's oldest dream became reality.
         This is the New York Times's boldest headline across America's newspaper record.
         The New York Times has documented our nation's most pivotal moments, but rarely
         has any story matched the cosmic significance of..."
| Quant      | Short (5.9s) | Long (17.4s) |
|------------|--------------|--------------|
| BF16       | PASS         | PASS         |
| Q8_0       | PASS         | PASS         |
| UD-Q8_K_XL | PASS         |              |
| Q6_K       | PASS         | PASS         |
| Q5_K_M     | PASS         | PASS         |
| Q5_K_S     | PASS         |              |
| Q4_0       | PASS         |              |
| IQ4_NL     | PASS         |              |
| Q4_K_M     | PASS         |              |
| Q4_K_S     | PASS         |              |
| Q4_1       | PASS         |              |
| IQ4_XS     | PASS         |              |
| Q3_K_M     | PASS         | 🔶           |
| Q3_K_S     | PASS         |              |

E2B short audio: 14/14 PASS. All quantizations correctly transcribe across all backends.

⏳ = model's thinking block consumed all tokens before outputting the transcription. Higher -n values resolve this.

E4B also tested: 19/21 short PASS, 20/21 long PASS/PARTIAL. Only ultra-low 2-bit quants (UD-IQ2_M, UD-IQ3_XXS) fail. See PR comment for full E4B matrix.

Resampling note: For best audio quality, provide input already at 16kHz. Audio at other sample rates is resampled to 16kHz with miniaudio's linear resampler regardless of container format (WAV, MP3, FLAC).
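
For intuition only, a toy linear-interpolation resampler follows. This is not miniaudio's implementation; it only illustrates why native 16kHz input avoids interpolation error entirely:

```cpp
#include <cstddef>
#include <vector>

// Toy linear (lerp) resampler to 16 kHz. NOT miniaudio's code; purely
// illustrative of what linear resampling does to a signal.
static std::vector<float> resample_linear(const std::vector<float> & in,
                                          int src_rate, int dst_rate = 16000) {
    if (src_rate == dst_rate || in.empty()) return in;
    const double step = (double) src_rate / dst_rate;
    std::vector<float> out;
    out.reserve((size_t)(in.size() / step) + 1);
    for (double pos = 0.0; pos + 1.0 < (double) in.size(); pos += step) {
        const size_t i = (size_t) pos;
        const float  t = (float)(pos - i);
        out.push_back(in[i] * (1.0f - t) + in[i + 1] * t); // lerp neighbours
    }
    return out;
}
```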

Generation parameters (from model's generation_config.json):
--temp 1.0 --top-k 64 --top-p 0.95

Additional information

Test plan:

  • test-mtmd-c-api passes
  • test-llama-archs passes
  • Full ctest suite (40/40 tests pass)
  • E2B all 14 quants: short audio PASS on CPU, Vulkan, CUDA
  • E2B BF16/Q6_K/Q8_0/Q5_K_M: long audio PASS on CPU, Vulkan, CUDA
  • E4B 19/21 quants: short audio PASS, 20/21 long PASS/PARTIAL
  • Mel values verified against PyTorch (cosine 1.0)
  • Encoder output verified against Python transformers (element-wise)
  • Conv norm mapping verified against safetensors weights
  • Tested with BF16 mmproj (recommended; F16/Q8_0 cause repetitions due to ClippableLinear sensitivity)

Dependency: #21625 (per-layer embedding scale for multimodal path) improves transcription reliability on longer audio (~17s+).

Ref: #21325
Related: #21599 (quantization fix: force Q6_K minimum for Gemma4 tied embeddings)
Related: #21612 (merged: perform per-layer projections in the first layer)
Related: #21625 (dependency: per-layer embedding scale for multimodal path)

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES - Claude Code and Gemini were used in an assistive capacity for iterative debugging (tensor comparison, mel spectrogram verification, conformer layer tracing, quantization analysis) and code review. All architecture decisions, algorithm implementations, and code were manually reviewed and verified against the PyTorch reference.

@github-actions bot added the documentation and examples labels Apr 4, 2026
@ggml-gh-bot (bot) commented Apr 4, 2026

Hi @stephencox-ict, thanks for your contribution!

Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:

  • AI-generated content: This project does not accept PRs, descriptions or commit messages that are fully or predominantly AI-generated. If you have used AI to assist you in writing code, please make sure to disclose that explicitly.

Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.

@ngxson (Contributor) commented Apr 4, 2026

Nice, seems to work but not 100% correct (using e4b, f16):

  1. Analyze the Request: The user wants me to transcribe the provided text.
  2. Analyze the Input Text: The text is a highly dramatic, repetitive, and emphatic piece of writing, likely intended as an urgent news headline or intro, with some slight repetition/redundancy in phrasing.
  3. Perform Transcription: I will transcribe the text exactly as it is written, preserving capitalization, punctuation, and structure.

Self-Correction/Verification: The input is slightly fragmented due to the rapid, breathless style, but the goal is faithful transcription.

  1. Final Output Generation.<channel|>"The Man on the Moon declared the New York Times from July 20th, 1969. This isn't just a news event. This is the moment when humanity's oldest dream became a reality. The Man on the Moon declared the New York Times, from July 20th, 1969. This isn't just a news event. This is the moment when humanity's oldest dream became a reality. The Man on the Moon declared the New York Times. This isn't just a news event. This is the moment when humanity's oldest dream became a reality. The Man on the Moon declared the New York Times."

However, the correct transcription should be:

The new york times from july 21 1969. this isn't just newsprint and ink this is the moment when humanities oldest dream became front page reality men walk on moon declares the bold headline across america's newspaper of record for over a century the new york times has documented our nation's most pivotal moments but rarely has any story matched the cosmic significance of this one.

@stephencox-ict (Contributor, Author) replied:
> Nice, seems to work but not 100% correct (using e4b, f16): […]

I haven't implemented chunked local self-attention yet. I'm focused on the testing side now and will come back to this.

@github-actions bot added the testing label Apr 4, 2026
@stephencox-ict force-pushed the gemma4-audio-pr branch 3 times, most recently from 83d1f37 to 13e9f5e on April 4, 2026 at 22:53
@stephencox-ict force-pushed the gemma4-audio-pr branch 3 times, most recently from 29dd32e to 7435a59 on April 5, 2026 at 00:15
@stephencox-ict marked this pull request as ready for review on April 5, 2026 at 00:16
@stephencox-ict requested review from a team and JohannesGaessler as code owners on April 5, 2026 at 00:16
@stephencox-ict requested a review from ngxson on April 5, 2026 at 00:19
@JohannesGaessler (Contributor) left a comment:

The changes to test-llama-archs.cpp LGTM otherwise. For some of the other files I'm seeing though that you are adding code comments with EM dashes. Please stick to ASCII unless there is a good reason not to.

@stephencox-ict (Contributor, Author) replied:
> The changes to test-llama-archs.cpp LGTM otherwise. For some of the other files I'm seeing though that you are adding code comments with EM dashes. Please stick to ASCII unless there is a good reason not to.

Fixed

stephencox and others added 2 commits April 12, 2026 09:56
Instead of loading tensors into the wrong fields and swapping
afterwards, load them directly into the correct fields by using
the reversed GGUF tensor names at the loading site. This is
cleaner and removes the need for the post-load swap loop.

Addresses review comment from ngxson on 2026-04-11.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace ggml_roll operations in the Gemma 4 audio conformer with
equivalent ggml_view + ggml_concat sequences. The ROLL op has no
Metal kernel, causing 73 graph splits and CPU fallbacks on Apple
Silicon that likely cause the repetitive output reported by ngxson.

With this change, all conformer ops run on a single backend
(graph splits reduced from 73 to 1).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ngxson (Contributor) commented Apr 11, 2026

I don't think replacing ggml_roll with something else is a valid solution. Unsupported ops fall back to the CPU implementation.

If your impl works on CUDA but fails on CPU (via --no-mmproj-offload), either CUDA or CPU is wrong.

@stephencox commented:

@ngxson CUDA output gist: https://gist.github.com/stephencox/9ccab7b860d5be9c7f8df97b9e9f9525

I investigated the repetition issue and found the root cause: ggml_roll has no Metal kernel, which was causing 73 graph splits in the audio conformer encoder (CPU fallbacks for every ROLL op). This likely caused numerical issues from the GPU↔CPU data marshaling across split boundaries.

Fix (1389eea): Replaced both ggml_roll calls in the conformer with equivalent ggml_view + ggml_concat sequences, which are supported on all backends including Metal.

Results:

  • Audio encoder graph splits: 73 → 1 (no more CPU fallbacks)
  • Tested on CPU and CUDA (RTX 3060) — both produce correct transcriptions
  • The unsupported ROLL ops warning should no longer appear on Metal

Could you try the latest commit and see if the repetitions are resolved on your M5 Max?

@stephencox commented:

Also a reminder to use the latest Unsloth GGUFs (the earlier conversions had issues).

@stephencox commented:

Fair point — I tested with --no-mmproj-offload on CUDA and it works correctly (CPU-only audio encoder, graph splits = 1, correct transcription). So the CPU ROLL implementation itself is fine.

The view+concat replacement still has value as a performance improvement (eliminates 73 graph splits on Metal, keeping the entire conformer on one backend), but you are right that it does not explain the repetitions if CPU fallback works correctly.

Could the repetition be related to the model/GGUF version? I have not been able to reproduce it on E2B BF16 (Unsloth). What model and mmproj are you using?

Restore the target_arch filter that was accidentally removed when
adding per-arch skip lists. Also remove redundant LLM_ARCH_UNKNOWN
check that was already handled above.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ngxson (Contributor) commented Apr 11, 2026

Hmm ok, thanks for the pointer. I tried the unsloth version (Q4_K_M text + BF16 mmproj) and it's indeed working without repetition.

I downloaded a fresh copy of https://huggingface.co/google/gemma-4-E4B-it and re-converted it. It turns out the mmproj is very sensitive to quantization:

  • BF16: works
  • F16: repetition
  • Q8_0: repetition

So I think for now, the only way is to keep BF16 for mmproj. I hope that will also fix some problems with image input (to be tested)

@stephencox commented:

Thanks for confirming. This matches what we found during validation — the Gemma 4 audio conformer uses ClippableLinear layers with tight clamp boundaries (e.g., input_max=12.8125). When weights are quantized from BF16 to F16 or lower, small numerical drift causes activations to hit/miss the clamp thresholds differently, which gets amplified through the 12 conformer layers (we measured cosine similarity dropping from 0.999 at layer 0 to 0.685 at layer 11).

So BF16 mmproj is required for now. The Unsloth GGUFs ship with BF16 mmproj which is why they work.
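
For intuition, a toy illustration of that sensitivity (the input_max value is the one quoted above; the activation values are hypothetical):

```cpp
#include <algorithm>
#include <cstdio>

// Why quantization hurts here: ClippableLinear clamps activations at tight
// per-tensor bounds. A small weight drift moves a borderline activation
// across the clamp threshold, and the error is then amplified through the
// 12 conformer layers. Hypothetical values, for illustration only.
static float clamp_act(float v, float bound) {
    return std::min(bound, std::max(-bound, v));
}

int main() {
    const float input_max = 12.8125f;             // from the PR discussion
    const float bf16_act  = 12.80f;               // just under: passes through
    const float f16_act   = 12.84f;               // tiny drift: gets clamped
    printf("bf16: %.4f -> %.4f\n", bf16_act, clamp_act(bf16_act, input_max));
    printf("f16 : %.4f -> %.4f\n", f16_act,  clamp_act(f16_act,  input_max));
    return 0;
}
```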

For the PR, should I add a note/warning about this in the code or docs?

@ngxson (Contributor) commented Apr 11, 2026

> For the PR, should I add a note/warning about this in the code or docs?

I think adding a note on top of this PR should be fine, no need to add to docs or comment. Something like:

Important

It is recommended to use BF16 mmproj. Other quantizations are known to have degraded performance; ref comment: #21421 (comment)

stephencox and others added 2 commits April 12, 2026 11:01
Keep only the gemma4-specific fixture params and skip entries.
The other arch skip lists (CLIP, GPTJ, CHAMELEON, RWKV, BERT,
PLM, WAVTOKENIZER_DEC, etc.) are unrelated to this PR.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ngxson (Contributor) left a review:
CC @CISC @pwilkin if someone can give an approval, thanks!

@ngxson (Contributor) commented Apr 12, 2026

For ref, the repetition seems to be due to causal attention being set incorrectly on the vision model. It should be fixed in #21824; I tested with Q8_0 mmproj and it works correctly now.

oussamaahmia pushed a commit to oussamaahmia/llama-cpp-turboquant-gemma4 that referenced this pull request Apr 13, 2026
* mtmd: add Gemma 4 audio conformer encoder support

Add audio processing for Gemma 4 E2B/E4B via a USM-style Conformer.

Architecture:
- 12-layer Conformer: FFN → Self-Attention → Causal Conv1D → FFN → Norm
- Subsampling Conv Projection: 2x Conv2D(stride=2) with LayerNorm
- Full self-attention with sinusoidal RPE and sliding window mask (24)
- Logit softcapping at 50.0, ClippableLinear clamping
- Output: 1024 → 1536 → RMSNorm → multimodal embedder

Mel preprocessing (dedicated mtmd_audio_preprocessor_gemma4a):
- HTK mel scale, 128 bins, magnitude STFT, mel_floor=1e-3
- Standard periodic Hann window (320 samples), zero-padded to FFT size
- Semicausal left-padding (frame_length/2 samples)
- Frame count matched to PyTorch (unfold formula)
- No pre-emphasis, no Whisper-style normalization
- Mel cosine similarity vs PyTorch: 0.9998

Key fixes:
- Tensor loading dedup: prevent get_tensor() from creating duplicate
  entries in ctx_data. Fixed with std::set guard.
- ClippableLinear clamp_info loading moved after per-layer tensors.
- Sliding window mask (24 positions) matching PyTorch context_size.
- Skip Whisper normalization for Gemma4 mel output.

Tested on E2B and E4B with CPU and Vulkan backends.
Transcribes: "Glad to see things are going well and business is starting
to pick up" (matching ground truth).

Ref: ggml-org#21325
crodjer added a commit to crodjer/llama.cpp that referenced this pull request Apr 13, 2026
* origin/master:
  webui: MCP Diagnostics improvements (ggml-org#21803)
  Remove extra conditional check on debug mode. (ggml-org#21798)
  sycl: disable Q1_0 in backend and cleanup unused variables (ggml-org#21807)
  mtmd: fix crash when sending image under 2x2 pixels (ggml-org#21711)
  mtmd: qwen3 audio support (qwen3-omni and qwen3-asr) (ggml-org#19441)
  convert : force f16 or f32 on step3-vl conv weights (ggml-org#21646)
  mtmd: add gemma 4 test (vision + audio) [no ci] (ggml-org#21806)
  mtmd: add Gemma 4 audio conformer encoder support (ggml-org#21421)
  fix: Proper messages rendering for "Show raw output" (ggml-org#21672)
  docs: add guide on how to add multimodal support (ggml-org#21778)
HermestoAizales pushed a commit to HermestoAizales/llama.cpp that referenced this pull request Apr 13, 2026
ArberSephirotheca pushed a commit to ArberSephirotheca/llama.cpp that referenced this pull request Apr 21, 2026
Labels

documentation, examples, ggml, model, Nvidia GPU, python, testing
