Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
05e72a1
mtmd: add Gemma 4 audio conformer encoder support
stephencox Apr 4, 2026
2d0961e
gemma4: fix audio encoder and LM precision issues
stephencox Apr 7, 2026
6bfeb87
mtmd: use double-precision math for audio preprocessing constants
stephencox Apr 7, 2026
56c8304
Revert "mtmd: use double-precision math for audio preprocessing const…
stephencox Apr 7, 2026
59e4872
gguf-py: restore gemma3n mappings in tensor_mapping.py and fix swappe…
stephencox Apr 8, 2026
df3ec22
address ngxson review: fix tensor mapping in C++, remove dup comment,…
stephencox Apr 8, 2026
15f2272
address review: remove cross-load, keep per-layer scale, derive softcap
stephencox Apr 8, 2026
1cd0924
address review: auto-detect swapped conv norms, remove dup comment
stephencox Apr 8, 2026
a67decf
address review: simplify conv norm swap, move scaling to PR #21625
stephencox Apr 8, 2026
62986e6
mtmd: fix CUDA/Vulkan conformer encoder by making sigmoid input conti…
stephencox Apr 8, 2026
e3c6854
revert tensor_mapping.py, gemma4-iswa.cpp and llama-model.cpp changes
stephencox Apr 8, 2026
282a669
gemma4: restore BF16-rounded scales and per-layer multimodal scaling
stephencox Apr 8, 2026
9096859
mtmd: fix conv norm swap comment in gemma4 audio conformer
stephencox Apr 10, 2026
4107adf
revert gemma4-iswa.cpp changes per reviewer request
stephencox Apr 11, 2026
558399c
mtmd: swap conv_norm/norm_conv at load site instead of post-load
stephencox Apr 11, 2026
1389eea
mtmd: replace ggml_roll with view+concat for Metal compatibility
stephencox Apr 11, 2026
bf3beed
tests: fix unused parameter warning in test-llama-archs save_models
stephencox Apr 11, 2026
23fd8fc
tests: remove non-gemma4 skip list changes from test-llama-archs
stephencox Apr 11, 2026
af69fcc
Revert "mtmd: replace ggml_roll with view+concat for Metal compatibil…
stephencox Apr 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/ssm-conv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
switch (nc) {
case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
case 5: launch_kernel(std::integral_constant<int, 5>{}); break;
case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now.");
}
}

Expand Down
21 changes: 20 additions & 1 deletion tests/test-llama-archs.cpp
Comment thread
stephencox-ict marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
uint32_t n_layer = 2;
if (arch == LLM_ARCH_LLAMA4) {
n_layer = 4; // hparams.n_no_rope_layer_step is hard-coded to 4
} else if (arch == LLM_ARCH_GEMMA4) {
n_embd = 128;
n_head = 2;
n_ff = 192;
n_layer = 5; // need at least 5 for swa_pattern (every 5th is full_attention)
} else if (arch == LLM_ARCH_GEMMA3N) {
n_embd = 64;
n_head = 1;
Expand Down Expand Up @@ -169,7 +174,15 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
ms.add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, uint32_t(8));
ms.add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, n_ctx/8);

if (arch == LLM_ARCH_MIMO2 || arch == LLM_ARCH_STEP35) {
if (arch == LLM_ARCH_GEMMA4) {
ms.add_kv(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, n_embd/2);
ms.add_kv(LLM_KV_ATTENTION_SHARED_KV_LAYERS, uint32_t(0));
ms.add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, n_embd_head);
ms.add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, n_embd_head);
ms.add_kv(LLM_KV_ROPE_FREQ_BASE_SWA, 10000.0f);
// SWA pattern: every 5th layer is full attention (matches E2B layer_types)
ms.add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, uint32_t(5));
} else if (arch == LLM_ARCH_MIMO2 || arch == LLM_ARCH_STEP35) {
std::vector<uint32_t> pattern;
pattern.reserve(n_layer);
for (uint32_t il = 0; il < n_layer; il++) {
Expand Down Expand Up @@ -429,6 +442,9 @@ static int save_models(const llm_arch target_arch, const size_t seed, const ggml
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
continue;
}
if (arch == LLM_ARCH_GEMMA4) {
continue; // FIXME: ISWA KV cache initialization needs more fixture params
}
for (bool moe : {false, true}) {
if (moe && !moe_implemented(arch)) {
continue;
Expand Down Expand Up @@ -510,6 +526,9 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
continue;
}
if (arch == LLM_ARCH_GEMMA4) {
continue; // FIXME: ISWA KV cache initialization needs more fixture params
}

const bool encode = arch == LLM_ARCH_T5 || arch == LLM_ARCH_DREAM || arch == LLM_ARCH_LLADA || arch == LLM_ARCH_LLADA_MOE || arch == LLM_ARCH_RND1;
for (bool moe : {false, true}) {
Expand Down
1 change: 1 addition & 0 deletions tools/mtmd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_library(mtmd
models/cogvlm.cpp
models/conformer.cpp
models/dotsocr.cpp
models/gemma4a.cpp
models/gemma4v.cpp
models/glm4v.cpp
models/hunyuanocr.cpp
Expand Down
15 changes: 15 additions & 0 deletions tools/mtmd/clip-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,21 @@
#define TN_CONV_PW1 "%s.blk.%d.conv_pw1.%s"
#define TN_CONV_PW2 "%s.blk.%d.conv_pw2.%s"

// gemma4 audio conformer
#define TN_A_MM_INP_PROJ "mm.a.input_projection.%s"
#define TN_A_MM_SOFT_EMB_N "mm.a.soft_emb_norm.%s"
#define TN_A_INP_PROJ "a.input_projection.%s"
#define TN_A_CONV1D "a.conv1d.%d.%s"
#define TN_A_CONV1D_NORM "a.conv1d.%d.norm.%s"
#define TN_A_OUT_PROJ "a.pre_encode.out.%s"
#define TN_A_ATTN_PRE_NORM "%s.blk.%d.attn_pre_norm.%s"
#define TN_A_ATTN_POST_NORM "%s.blk.%d.attn_post_norm.%s"
#define TN_A_ATTN_K_REL "%s.blk.%d.attn_k_rel.%s"
#define TN_A_PER_DIM_SCALE "%s.blk.%d.per_dim_scale.%s"
#define TN_A_PER_DIM_K_SCALE "%s.blk.%d.per_dim_k_scale.%s"
#define TN_A_FFN_POST_NORM "%s.blk.%d.ffn_post_norm.%s"
#define TN_A_FFN_POST_NORM_1 "%s.blk.%d.ffn_post_norm_1.%s"

// mobilenetv5 (gemma3n) definitions
#define TN_MNV5_STEM_CONV "v.conv_stem.conv.weight"
#define TN_MNV5_STEM_BIAS "v.conv_stem.conv.bias"
Expand Down
16 changes: 16 additions & 0 deletions tools/mtmd/clip-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,13 @@ struct clip_layer {
ggml_tensor * conv_pw2_w = nullptr;
ggml_tensor * conv_pw2_b = nullptr;

// gemma4 audio conformer per-layer
ggml_tensor * attn_pre_norm_w = nullptr;
ggml_tensor * attn_k_rel_w = nullptr;
ggml_tensor * per_dim_scale_w = nullptr;
ggml_tensor * per_dim_k_scale_w = nullptr;
ggml_tensor * ff_post_norm_1_w = nullptr;

bool has_deepstack() const {
return deepstack_fc1_w != nullptr;
}
Expand Down Expand Up @@ -459,6 +466,15 @@ struct clip_model {
};
std::map<std::string, clamp_info> clamp_info_map;

// gemma4 audio conformer
std::array<ggml_tensor *, 2> sscp_conv_w = {nullptr};
std::array<ggml_tensor *, 2> sscp_conv_b = {nullptr};
std::array<ggml_tensor *, 2> sscp_norm_w = {nullptr};
ggml_tensor * sscp_inp_proj_w = nullptr;
ggml_tensor * sscp_inp_proj_b = nullptr;
ggml_tensor * audio_out_proj_w = nullptr;
ggml_tensor * audio_out_proj_b = nullptr;

bool audio_has_avgpool() const {
return proj_type == PROJECTOR_TYPE_QWEN2A
|| proj_type == PROJECTOR_TYPE_VOXTRAL
Expand Down
162 changes: 158 additions & 4 deletions tools/mtmd/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
builder = std::make_unique<clip_graph_conformer>(ctx, img);
} break;
case PROJECTOR_TYPE_GEMMA4A:
{
builder = std::make_unique<clip_graph_gemma4a>(ctx, img);
} break;
case PROJECTOR_TYPE_GLM4V:
{
builder = std::make_unique<clip_graph_glm4v>(ctx, img);
Expand Down Expand Up @@ -1456,6 +1460,16 @@ struct clip_model_loader {
hparams.audio_window_len = 400;
hparams.audio_hop_len = 160;
} break;
case PROJECTOR_TYPE_GEMMA4A:
{
// Gemma4 feature_extraction_gemma4.py:
// frame_length_ms=20 -> 320 samples, n_fft=512, hop=10ms -> 160
hparams.audio_chunk_len = 0; // no fixed-length padding
hparams.audio_sample_rate = 16000;
hparams.audio_n_fft = 512;
hparams.audio_window_len = 320; // 20ms frame (NOT 25ms/400)
hparams.audio_hop_len = 160;
} break;
case PROJECTOR_TYPE_JANUS_PRO:
{
hparams.image_pad_color = {127, 127, 127};
Expand Down Expand Up @@ -1558,16 +1572,21 @@ struct clip_model_loader {
}

// helper function
std::unordered_set<std::string> loaded_tensor_names;
auto get_tensor = [&](const std::string & name, bool required = true) {
// Each tensor should only be loaded once; duplicates indicate a bug
if (loaded_tensor_names.count(name)) {
throw std::runtime_error(string_format("%s: tensor already loaded: %s\n", __func__, name.c_str()));
}
Comment thread
stephencox-ict marked this conversation as resolved.
ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str());
if (!cur && required) {
throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str()));
}
if (cur) {
tensors_to_load.push_back(cur);
// add tensors to context
ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
ggml_set_name(data_tensor, cur->name);
loaded_tensor_names.insert(name);
cur = data_tensor;
}
return cur;
Expand Down Expand Up @@ -2159,6 +2178,76 @@ struct clip_model_loader {
model.mm_fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
model.mm_fc_b = get_tensor(string_format(TN_MM_PROJECTOR, "bias"));
} break;
case PROJECTOR_TYPE_GEMMA4A:
{
for (int i = 0; i < 2; i++) {
model.sscp_conv_w[i] = get_tensor(string_format(TN_A_CONV1D, i, "weight"));
model.sscp_conv_b[i] = get_tensor(string_format(TN_A_CONV1D, i, "bias"), false);
model.sscp_norm_w[i] = get_tensor(string_format(TN_A_CONV1D_NORM, i, "weight"), false);
}
model.sscp_inp_proj_w = get_tensor(string_format(TN_A_INP_PROJ, "weight"));
model.sscp_inp_proj_b = get_tensor(string_format(TN_A_INP_PROJ, "bias"), false);
model.audio_out_proj_w = get_tensor(string_format(TN_A_OUT_PROJ, "weight"), false);
model.audio_out_proj_b = get_tensor(string_format(TN_A_OUT_PROJ, "bias"), false);
// audio multimodal embedder (mm.a.* namespace, not mm.*)
model.mm_soft_emb_norm_w = get_tensor(string_format(TN_A_MM_SOFT_EMB_N, "weight"), false);
model.mm_input_proj_w = get_tensor(string_format(TN_A_MM_INP_PROJ, "weight"), false);

// Per-layer tensors NOT loaded by the generic loop above
for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = model.layers[il];

// Gemma4 audio conformer-specific tensors
layer.ff_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight"));
layer.attn_pre_norm_w = get_tensor(string_format(TN_A_ATTN_PRE_NORM, prefix, il, "weight"), false);
layer.per_dim_scale_w = get_tensor(string_format(TN_A_PER_DIM_SCALE, prefix, il, "weight"), false);
layer.per_dim_k_scale_w = get_tensor(string_format(TN_A_PER_DIM_K_SCALE, prefix, il, "weight"), false);
layer.attn_k_rel_w = get_tensor(string_format(TN_A_ATTN_K_REL, prefix, il, "weight"), false);

// Convolution module
// Note: conv_norm / norm_conv are swapped in GGUF due to
// upstream tensor_mapping.py, so we load them in reverse order
layer.norm_conv_w = get_tensor(string_format(TN_CONV_NORM, prefix, il, "weight"), false);
layer.norm_conv_b = get_tensor(string_format(TN_CONV_NORM, prefix, il, "bias"), false);
layer.conv_pw1_w = get_tensor(string_format(TN_CONV_PW1, prefix, il, "weight"));
layer.conv_pw1_b = get_tensor(string_format(TN_CONV_PW1, prefix, il, "bias"), false);
layer.conv_dw_w = get_tensor(string_format(TN_CONV_DW, prefix, il, "weight"));
layer.conv_dw_b = get_tensor(string_format(TN_CONV_DW, prefix, il, "bias"), false);
layer.conv_norm_w = get_tensor(string_format(TN_NORM_CONV, prefix, il, "weight"), false);
layer.conv_norm_b = get_tensor(string_format(TN_NORM_CONV, prefix, il, "bias"), false);
layer.conv_pw2_w = get_tensor(string_format(TN_CONV_PW2, prefix, il, "weight"));
layer.conv_pw2_b = get_tensor(string_format(TN_CONV_PW2, prefix, il, "bias"), false);

// FFN2 (second half-step)
layer.ff_norm_1_w = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "weight"));
layer.ff_up_1_w = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "weight"));
layer.ff_up_1_b = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "bias"), false);
layer.ff_down_1_w = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "weight"));
layer.ff_down_1_b = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "bias"), false);
layer.ff_post_norm_1_w = get_tensor(string_format(TN_A_FFN_POST_NORM_1, prefix, il, "weight"), false);
}

// Load clamp info for ClippableLinear AFTER all tensors are loaded
for (auto * tensor : tensors_to_load) {
std::string name = tensor->name;
if (string_ends_with(name, ".weight")) {
std::string name_inp_max = name;
std::string name_inp_min = name;
std::string name_out_max = name;
std::string name_out_min = name;
string_replace_all(name_inp_max, ".weight", ".input_max");
string_replace_all(name_inp_min, ".weight", ".input_min");
string_replace_all(name_out_max, ".weight", ".output_max");
string_replace_all(name_out_min, ".weight", ".output_min");
model.clamp_info_map[name] = {
get_scalar(name_inp_max, FLT_MAX),
get_scalar(name_inp_min, -FLT_MAX),
get_scalar(name_out_max, FLT_MAX),
get_scalar(name_out_min, -FLT_MAX)
};
}
}
} break;
case PROJECTOR_TYPE_LFM2A:
{
for (int i : {0, 2, 3, 5, 6}) {
Expand Down Expand Up @@ -2219,7 +2308,10 @@ struct clip_model_loader {
ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
for (auto & t : tensors_to_load) {
ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
const size_t offset = tensor_offset[t->name];
GGML_ASSERT(cur && "tensor not found in ctx_data");
auto it_off = tensor_offset.find(t->name);
GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor");
const size_t offset = it_off->second;
fin.seekg(offset, std::ios::beg);
if (!fin) {
throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
Expand All @@ -2239,6 +2331,7 @@ struct clip_model_loader {

LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
}

}

struct support_info_op {
Expand Down Expand Up @@ -2511,8 +2604,7 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params

// TODO: we don't support audio for Gemma 3N, but GGUF contains audio tensors
// we can remove this check when we implement audio support for Gemma 3N
skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV
|| ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA4V;
skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV;
}

if (loader.has_audio && !skip_audio) {
Expand Down Expand Up @@ -2865,6 +2957,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
{
n_patches = ((((img->nx + 1) / 2) + 1) / 2 + 1) / 2;
} break;
case PROJECTOR_TYPE_GEMMA4A:
{
// Two Conv2D stride-2: O = floor((I + 2p - k) / s) + 1, p=1, k=3, s=2
// O = floor((I - 1) / 2) + 1
int n = img->nx;
for (int i = 0; i < 2; i++) {
n = (n - 1) / 2 + 1;
}
n_patches = n;
} break;
default:
GGML_ABORT("unsupported projector type");
}
Expand Down Expand Up @@ -3323,6 +3425,56 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
set_input_i32("pos_w", pos_data);
} break;
case PROJECTOR_TYPE_GEMMA4A:
{
GGML_ASSERT(imgs.entries.size() == 1);
const auto & img0 = imgs.entries.front();
// Compute n_pos matching SSCP output: two stride-2 convs
int n_pos = img0->nx;
for (int i = 0; i < 2; i++) { n_pos = (n_pos - 1) / 2 + 1; }

// Chunked local attention: blocked causal mask and RPE
const int chunk_size = 12;
const int max_past = 12;
const int context_size = chunk_size + max_past;
const int num_blocks = (n_pos + chunk_size - 1) / chunk_size;

// Blocked causal attention mask: [context_size, chunk_size, num_blocks]
{
std::vector<float> mask(context_size * chunk_size * num_blocks, -1e9f);
for (int b = 0; b < num_blocks; b++) {
for (int q = 0; q < chunk_size; q++) {
int gq = b * chunk_size + q;
for (int k = 0; k < context_size; k++) {
int gk = b * chunk_size - max_past + k;
if (gq < n_pos && gk >= 0 && gk < n_pos && gk <= gq && (gq - gk) < max_past) {
mask[k + q * context_size + b * context_size * chunk_size] = 0.0f;
}
}
}
}
set_input_f32("kq_mask", mask);
}

// Sinusoidal RPE: 13 positions [12, 11, ..., 0]
{
const int n_embd = ctx->model.hparams.n_embd;
const int num_timescales = n_embd / 2;
const float log_timescale_increment = logf(10000.0f) / std::max(num_timescales - 1, 1);
const int rpe_len = max_past + 1;
std::vector<float> pos_emb(n_embd * rpe_len, 0.0f);
for (int p = 0; p < rpe_len; p++) {
float position = (float)(max_past - p);
for (int i = 0; i < num_timescales; i++) {
float inv_ts = expf(-(float)i * log_timescale_increment);
float scaled = position * inv_ts;
pos_emb[p * n_embd + i] = sinf(scaled);
pos_emb[p * n_embd + i + num_timescales] = cosf(scaled);
}
}
set_input_f32("pos_emb", pos_emb);
}
} break;
case PROJECTOR_TYPE_LFM2A:
{
GGML_ASSERT(imgs.entries.size() == 1);
Expand Down Expand Up @@ -3485,6 +3637,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_fc_w->ne[1];
case PROJECTOR_TYPE_LFM2A:
return ctx->model.position_embeddings->ne[0];
case PROJECTOR_TYPE_GEMMA4A:
return ctx->model.hparams.projection_dim;
case PROJECTOR_TYPE_GLM4V:
return ctx->model.mm_ffn_down_w->ne[1];
default:
Expand Down
Loading