From 1f0df18e0d274d237201dc212dfc125cc8d85920 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 28 Nov 2025 14:01:50 +0100 Subject: [PATCH 1/7] convert gguf --- convert_hf_to_gguf.py | 31 +++++++++++++++++++++++++++++-- gguf-py/gguf/constants.py | 3 +++ gguf-py/gguf/gguf_writer.py | 6 ++++++ gguf-py/gguf/tensor_mapping.py | 12 ++++++++++++ tools/mtmd/clip-impl.h | 2 ++ tools/mtmd/clip.cpp | 5 +++++ 6 files changed, 57 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 866aa536f19..eb0076c3d2a 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3348,7 +3348,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) -@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration") +@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "DotsOCRForCausalLM") class Qwen2Model(TextModel): model_arch = gguf.MODEL_ARCH.QWEN2 @@ -3374,7 +3374,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name.replace("language_model.", "") # for InternVL if name.startswith("mlp") or name.startswith("multi_modal_projector") \ or name.startswith("vision_model") or name.startswith("audio_tower") \ - or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"): + or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") \ + or name.startswith("vision_tower."): # skip vision and audio tensors return [] yield from super().modify_tensors(data_torch, name, bid) @@ -10074,6 +10075,32 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] +@ModelBase.register("DotsOCRForCausalLM") +class DotsOCRVisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + self.hparams_vision["image_size"] = 0 # dynamic resolution + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.DOTSOCR) + self.gguf_writer.add_vision_image_min_pixels(self.preprocessor_config["min_pixels"]) + self.gguf_writer.add_vision_image_max_pixels(self.preprocessor_config["max_pixels"]) + self.gguf_writer.add_vision_attention_layernorm_eps(self.find_vparam(["rms_norm_eps"])) + self.gguf_writer.add_vision_projector_scale_factor(self.find_vparam(["spatial_merge_size"])) + self.gguf_writer.add_vision_use_silu(True) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.startswith("vision_tower."): + print(name) + return [(self.map_tensor_name(name), data_torch)] + + return [] # skip other tensors + + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 266d19f9dd7..3968ca0a687 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -280,6 +280,8 @@ class Clip: class ClipVision: IMAGE_SIZE = "clip.vision.image_size" + IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels" + IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels" PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" PATCH_SIZE = "clip.vision.patch_size" EMBEDDING_LENGTH = "clip.vision.embedding_length" @@ -3304,6 +3306,7 @@ class VisionProjectorType: LIGHTONOCR = "lightonocr" COGVLM = "cogvlm" JANUS_PRO = "janus_pro" + DOTSOCR = "dots_ocr" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 57ca2035fe2..28bd106f931 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1118,6 +1118,12 @@ def add_vision_n_wa_pattern(self, value: int) -> None: def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) + def add_vision_image_max_pixels(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.IMAGE_MAX_PIXELS, value) + + def add_vision_image_min_pixels(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value) + # audio models def add_audio_projection_dim(self, value: int) -> None: diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a7b09739791..c5437bd023d 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1189,6 +1189,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", "visual.merger.mlp.{bid}", # qwen2vl + "vision_tower.merger.mlp.{bid}", # dots.ocr ), MODEL_TENSOR.V_MMPROJ_FC: ( @@ -1225,6 +1226,7 @@ class TensorNameMap: "visual.patch_embed.proj", # qwen2vl "vision_tower.patch_embed.proj", # kimi-vl "model.vision.patch_embedding.proj", # cogvlm + "vision_tower.patch_embed.patchifier.proj", # dots.ocr ), MODEL_TENSOR.V_ENC_EMBD_POS: ( @@ -1240,6 +1242,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_ATTN_QKV: ( "visual.blocks.{bid}.attn.qkv", # qwen3vl + "vision_tower.blocks.{bid}.attn.qkv", # dots.ocr "model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm ), @@ -1301,6 +1304,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm1", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm + "vision_tower.blocks.{bid}.norm1", # dots.ocr ), MODEL_TENSOR.V_ENC_ATTN_O: ( @@ -1316,6 +1320,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.proj", # qwen2vl "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm + "vision_tower.blocks.{bid}.attn.proj", # dots.ocr ), MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( @@ -1330,6 +1335,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm2", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm + "vision_tower.blocks.{bid}.norm2", # dots.ocr ), MODEL_TENSOR.V_ENC_FFN_UP: ( @@ -1345,12 +1351,14 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm + "vision_tower.blocks.{bid}.mlp.fc2", # dots.ocr ), MODEL_TENSOR.V_ENC_FFN_GATE: ( "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral-hf "vision_encoder.transformer.layers.{bid}.feed_forward.w1", # pixtral "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl + "vision_tower.blocks.{bid}.mlp.fc1", # dots.ocr ), MODEL_TENSOR.V_ENC_FFN_DOWN: ( @@ -1366,6 +1374,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm + "vision_tower.blocks.{bid}.mlp.fc3", # dots.ocr ), MODEL_TENSOR.V_LAYER_SCALE_1: ( @@ -1383,6 +1392,7 @@ class TensorNameMap: "vision_tower.ln_pre", # pixtral-hf "vision_encoder.ln_pre", # pixtral "vision_model.layernorm_pre", # llama4 + "vision_tower.patch_embed.patchifier.norm", # dots.ocr ), MODEL_TENSOR.V_POST_NORM: ( @@ -1391,6 +1401,7 @@ class TensorNameMap: "vision_model.layernorm_post", # llama4 "visual.merger.ln_q", # qwen2vl "vision_tower.encoder.final_layernorm", # kimi-vl + "vision_tower.post_trunk_norm", # dots.ocr ), MODEL_TENSOR.V_MM_INP_PROJ: ( @@ -1403,6 +1414,7 @@ class TensorNameMap: "multi_modal_projector.pre_norm", "pre_mm_projector_norm", "model.vision.linear_proj.norm1", # cogvlm + "vision_tower.merger.ln_q", # dots.ocr ), MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index cd47865bf4a..f1c9ed47f37 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -32,6 +32,8 @@ // vision-specific #define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities #define KEY_IMAGE_SIZE "clip.vision.image_size" +#define KEY_IMAGE_MIN_PIXELS "clip.vision.image_min_pixels" +#define KEY_IMAGE_MAX_PIXELS "clip.vision.image_max_pixels" #define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size" #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 52ea542decc..cadee151496 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2663,6 +2663,11 @@ struct clip_model_loader { if (is_vision) { get_u32(KEY_IMAGE_SIZE, hparams.image_size); + get_u32(KEY_IMAGE_MIN_PIXELS, hparams.image_min_pixels, false); + get_u32(KEY_IMAGE_MAX_PIXELS, hparams.image_max_pixels, false); + if (hparams.image_size == 0 && hparams.image_min_pixels == -1 && hparams.image_max_pixels == -1) { + throw std::runtime_error("one of: image_size, image_min_pixels, and image_max_pixels must be defined\n"); + } get_u32(KEY_PATCH_SIZE, hparams.patch_size); get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy From 9149ff70f11c165c16e5b961211bd2d9ef623360 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 28 Nov 2025 15:35:30 +0100 Subject: [PATCH 2/7] clip impl --- convert_hf_to_gguf.py | 1 - gguf-py/gguf/tensor_mapping.py | 4 +- tools/mtmd/clip-impl.h | 2 + tools/mtmd/clip.cpp | 176 +++++++++++++++++++++++++++------ tools/mtmd/mtmd.cpp | 5 + 5 files changed, 153 insertions(+), 35 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index eb0076c3d2a..18cda2870e5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10095,7 +10095,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter del bid # unused if name.startswith("vision_tower."): - print(name) return [(self.map_tensor_name(name), data_torch)] return [] # skip other tensors diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index c5437bd023d..f708b03cafa 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1351,7 +1351,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm - "vision_tower.blocks.{bid}.mlp.fc2", # dots.ocr + "vision_tower.blocks.{bid}.mlp.fc3", # dots.ocr ), MODEL_TENSOR.V_ENC_FFN_GATE: ( @@ -1374,7 +1374,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm - "vision_tower.blocks.{bid}.mlp.fc3", # dots.ocr + "vision_tower.blocks.{bid}.mlp.fc2", # dots.ocr ), MODEL_TENSOR.V_LAYER_SCALE_1: ( diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index f1c9ed47f37..9b25553e65a 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -158,6 +158,7 @@ enum projector_type { PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_COGVLM, PROJECTOR_TYPE_JANUS_PRO, + PROJECTOR_TYPE_DOTS_OCR, PROJECTOR_TYPE_UNKNOWN, }; @@ -184,6 +185,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, + { PROJECTOR_TYPE_DOTS_OCR, "dots_ocr"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index cadee151496..c6e9bd97609 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -381,6 +381,7 @@ struct clip_model { // pixtral ggml_tensor * token_embd_img_break = nullptr; ggml_tensor * mm_patch_merger_w = nullptr; + ggml_tensor * mm_patch_merger_b = nullptr; // ultravox / whisper encoder ggml_tensor * conv1d_1_w = nullptr; @@ -1839,15 +1840,7 @@ struct clip_graph { if (model.audio_has_stack_frames()) { // StackAudioFrames // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py - int64_t stride = n_embd * hparams.proj_stack_factor; - int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride); - int64_t pad = padded_len - ggml_nelements(cur); - if (pad > 0) { - cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0); - cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); - } - cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, - ggml_row_size(cur->type, stride), 0); + cur = build_stacked_embeddings(cur, hparams.proj_stack_factor); cb(cur, "after_stacked", -1); } @@ -1991,6 +1984,48 @@ struct clip_graph { return gf; } + ggml_cgraph * build_dots_ocr() { + // 2D input positions + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, false); + }; + + ggml_tensor * inp = build_inp(); + ggml_tensor * cur = build_vit( + inp, n_patches, + NORM_TYPE_RMS, + hparams.ffn_op, + nullptr, // no learned pos embd + add_pos); + + // dots.ocr patch merger + projector + { + GGML_ASSERT(hparams.n_merge > 0); + cur = build_norm(cur, model.mm_input_norm_w, model.mm_input_norm_b, NORM_TYPE_NORMAL, 1e-6, -1); + cur = build_stacked_embeddings(cur, hparams.n_merge * hparams.n_merge); + cb(cur, "after_patch_merger", -1); + cur = build_ffn(cur, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, // no gate + model.mm_2_w, model.mm_2_b, + FFN_GELU, -1); + cb(cur, "after_projector", -1); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; + } + private: // // utility functions @@ -2065,34 +2100,69 @@ struct clip_graph { // self-attention { - ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); - if (layer.q_b) { - Qcur = ggml_add(ctx0, Qcur, layer.q_b); - } + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + if (layer.qkv_w != nullptr) { + // fused qkv + cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); + if (layer.qkv_b != nullptr) { + cur = ggml_add(ctx0, cur, layer.qkv_b); + } - ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); - if (layer.k_b) { - Kcur = ggml_add(ctx0, Kcur, layer.k_b); - } + Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + d_head * ggml_element_size(cur), cur->nb[1], + /* offset */ 0); - ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); - if (layer.v_b) { - Vcur = ggml_add(ctx0, Vcur, layer.v_b); - } + Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + d_head * ggml_element_size(cur), cur->nb[1], + /* offset */ n_embd * ggml_element_size(cur)); - if (layer.q_norm) { - Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il); - cb(Qcur, "Qcur_norm", il); - } + Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + d_head * ggml_element_size(cur), cur->nb[1], + /* offset */ 2 * n_embd * ggml_element_size(cur)); - if (layer.k_norm) { - Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il); - cb(Kcur, "Kcur_norm", il); - } + if (layer.q_norm) { + Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il); + cb(Qcur, "Qcur_norm", il); + } - Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos); - Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos); - Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos); + if (layer.k_norm) { + Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il); + cb(Kcur, "Kcur_norm", il); + } + + } else { + // separate q, k, v + Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); + if (layer.q_b) { + Qcur = ggml_add(ctx0, Qcur, layer.q_b); + } + + Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); + if (layer.k_b) { + Kcur = ggml_add(ctx0, Kcur, layer.k_b); + } + + Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); + if (layer.v_b) { + Vcur = ggml_add(ctx0, Vcur, layer.v_b); + } + + if (layer.q_norm) { + Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il); + cb(Qcur, "Qcur_norm", il); + } + + if (layer.k_norm) { + Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il); + cb(Kcur, "Kcur_norm", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos); + Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos); + Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -2362,6 +2432,20 @@ struct clip_graph { return cur; } + // stack N consecutive rows into one row + ggml_tensor * build_stacked_embeddings(ggml_tensor * cur, int n_stack) { + int64_t stride = cur->ne[0] * n_stack; + int64_t padded_len = CLIP_ALIGN(ggml_nelements(cur), stride); + int64_t pad = padded_len - ggml_nelements(cur); + if (pad > 0) { + cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0); + cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); + } + cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, + ggml_row_size(cur->type, stride), 0); + return cur; + } + // implementation of the 2D RoPE without adding a new op in ggml // this is not efficient (use double the memory), but works on all backends // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065 @@ -2524,6 +2608,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_cogvlm(); } break; + case PROJECTOR_TYPE_DOTS_OCR: + { + res = graph.build_dots_ocr(); + } break; default: { res = graph.build_llava(); @@ -2838,6 +2926,12 @@ struct clip_model_loader { LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__); } } break; + case PROJECTOR_TYPE_DOTS_OCR: + { + hparams.rope_theta = 10000.0f; + get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); + hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_LLAMA4: { hparams.rope_theta = 10000.0f; @@ -3244,6 +3338,15 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); } break; + case PROJECTOR_TYPE_DOTS_OCR: + { + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); + model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B); + } break; default: GGML_ASSERT(false && "unknown projector type"); } @@ -4318,6 +4421,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: + case PROJECTOR_TYPE_DOTS_OCR: { GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0); clip_image_u8 resized_image; @@ -4594,6 +4698,12 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im { n_patches += 2; // for BOI and EOI token embeddings } break; + case PROJECTOR_TYPE_DOTS_OCR: + { + // dynamic size + int n_stack = params.n_merge * params.n_merge; + n_patches = CLIP_ALIGN(n_patches, n_stack) / n_stack; + } break; default: GGML_ABORT("unsupported projector type"); } @@ -4870,6 +4980,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_KIMIVL: case PROJECTOR_TYPE_LIGHTONOCR: + case PROJECTOR_TYPE_DOTS_OCR: { // set the 2D positions int n_patches_per_col = image_size_width / patch_size; @@ -5003,6 +5114,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: + case PROJECTOR_TYPE_DOTS_OCR: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_MLP_NORM: return ctx->model.mm_3_b->ne[0]; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index dfad9cd7957..f22703cff62 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -304,6 +304,11 @@ struct mtmd_context { img_beg = "<|im_start|>"; img_end = "<|im_end|>"; + } else if (proj == PROJECTOR_TYPE_DOTS_OCR) { + // <|img|> ... (image embeddings) ... <|endofimg|> + img_beg = "<|img|>"; + img_end = "<|endofimg|>"; + } } From 524a9a4005f621c878cc36b3dd2537cd151d3528 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 5 Apr 2026 17:05:31 +0200 Subject: [PATCH 3/7] fix conversion --- convert_hf_to_gguf.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 709382b4871..ef4ae9cbec8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -12680,19 +12680,23 @@ def __init__(self, *args, **kwargs): def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.DOTSOCR) - self.gguf_writer.add_vision_image_min_pixels(self.preprocessor_config["min_pixels"]) - self.gguf_writer.add_vision_image_max_pixels(self.preprocessor_config["max_pixels"]) + self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"]) + self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"]) self.gguf_writer.add_vision_attention_layernorm_eps(self.find_vparam(["rms_norm_eps"])) self.gguf_writer.add_vision_projector_scale_factor(self.find_vparam(["spatial_merge_size"])) self.gguf_writer.add_vision_use_silu(True) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - if name.startswith("vision_tower."): - return [(self.map_tensor_name(name), data_torch)] - - return [] # skip other tensors + if "vision_tower.blocks." in name and ".mlp." in name: + # note: to avoid naming conflicts in tensor_mapping.py, we need to handle FFN renaming here + # fc1 -> gate, fc2 -> up, fc3 -> down + # mapping original names to Qwen2.5 naming scheme + name = name.replace("vision_tower.blocks.", "visual.blocks.") + name = name.replace(".fc1", ".gate_proj") + name = name.replace(".fc2", ".up_proj") + name = name.replace(".fc3", ".down_proj") + yield from super().modify_tensors(data_torch, name, bid) ###### CONVERSION LOGIC ###### From d3067ccf4bc6f9982b923ea213512045cc86c998 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 5 Apr 2026 17:47:23 +0200 Subject: [PATCH 4/7] wip --- convert_hf_to_gguf.py | 8 +++-- gguf-py/gguf/tensor_mapping.py | 3 +- tools/mtmd/CMakeLists.txt | 1 + tools/mtmd/clip.cpp | 53 ++++++++++++++++++++++++++++++++++ tools/mtmd/models/dotsocr.cpp | 49 +++++++++++++++++++++++++++++++ tools/mtmd/models/models.h | 5 ++++ tools/mtmd/mtmd.cpp | 7 +++++ 7 files changed, 122 insertions(+), 4 deletions(-) create mode 100644 tools/mtmd/models/dotsocr.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ef4ae9cbec8..ffe30359a8f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -12690,12 +12690,14 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.startswith("vision_tower."): if "vision_tower.blocks." in name and ".mlp." in name: # note: to avoid naming conflicts in tensor_mapping.py, we need to handle FFN renaming here - # fc1 -> gate, fc2 -> up, fc3 -> down + # x = F.silu(self.fc1(x)) * self.fc3(x) + # x = self.fc2(x) + # fc1 -> gate, fc2 -> down, fc3 -> up # mapping original names to Qwen2.5 naming scheme name = name.replace("vision_tower.blocks.", "visual.blocks.") name = name.replace(".fc1", ".gate_proj") - name = name.replace(".fc2", ".up_proj") - name = name.replace(".fc3", ".down_proj") + name = name.replace(".fc2", ".down_proj") + name = name.replace(".fc3", ".up_proj") yield from super().modify_tensors(data_torch, name, bid) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 2aaa0372fc3..63ee664e61e 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1404,11 +1404,12 @@ class TensorNameMap: "siglip2.vision_model.embeddings.patch_embedding", "vision_model.radio_model.model.patch_generator.embedder", # Nemotron Nano v2 VL "model.vision_tower.patch_embedder.input_proj", # gemma4 + "vision_tower.patch_embed.patchifier.proj", # dots.ocr ), MODEL_TENSOR.V_ENC_EMBD_NORM: ( "visual.post_conv_layernorm", # glm4v - "vision_tower.patch_embed.patchifier.proj", # dots.ocr + "vision_tower.patch_embed.patchifier.norm", # dots.ocr ), MODEL_TENSOR.V_ENC_EMBD_POS: ( diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 675464c6b5f..244a835c2c6 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -17,6 +17,7 @@ add_library(mtmd models/models.h models/cogvlm.cpp models/conformer.cpp + models/dotsocr.cpp models/gemma4v.cpp models/glm4v.cpp models/internvl.cpp diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 12517123e7c..11c9f9937b8 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -853,6 +853,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_DOTS_OCR: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: { @@ -1261,6 +1265,14 @@ struct clip_model_loader { get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.image_longest_edge, false); hparams.set_warmup_n_tokens(256); // avoid OOM on warmup } break; + case PROJECTOR_TYPE_DOTS_OCR: + { + hparams.rope_theta = 10000.0f; + get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge); + get_u32(KEY_IMAGE_MIN_PIXELS, hparams.image_min_pixels); + get_u32(KEY_IMAGE_MAX_PIXELS, hparams.image_max_pixels); + hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_KIMIVL: { hparams.image_resize_algo = RESIZE_ALGO_BILINEAR; @@ -1948,6 +1960,15 @@ struct clip_model_loader { model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); model.mm_patch_merger_w = get_tensor(string_format(TN_MM_PATCH_MERGER, "weight"), false); } break; + case PROJECTOR_TYPE_DOTS_OCR: + { + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); + model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B); + } break; case PROJECTOR_TYPE_ULTRAVOX: { model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); @@ -2701,6 +2722,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im n_patches = x_patch * y_patch; } break; case PROJECTOR_TYPE_PADDLEOCR: + case PROJECTOR_TYPE_DOTS_OCR: { // dynamic size int n_merge = ctx->model.hparams.n_merge; @@ -2990,6 +3012,36 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } + set_input_i32("positions", positions); + } break; + case PROJECTOR_TYPE_DOTS_OCR: + { + const int merge_ratio = hparams.n_merge; + const int pw = image_size_width / patch_size; + const int ph = image_size_height / patch_size; + + // For dots.ocr we need [total_patches, 2] -> flattened as (h_pos, w_pos) pairs + const int n_pos = ph * pw; + std::vector positions(n_pos * 4); + int ptr = 0; + + // 4 nested loops like GLM-4V, but emitting (y, x) pairs instead of duplicating + for (int y = 0; y < ph; y += merge_ratio) { + for (int x = 0; x < pw; x += merge_ratio) { + for (int dy = 0; dy < merge_ratio; dy++) { + for (int dx = 0; dx < merge_ratio; dx++) { + const int ypos = y + dy; + const int xpos = x + dx; + positions[ptr * 2 + 0] = ypos; // height position + positions[ptr * 2 + 1] = xpos; // width position + positions[ptr * 2 + 2] = ypos; + positions[ptr * 2 + 3] = xpos; + ptr++; + } + } + } + } + set_input_i32("positions", positions); } break; case PROJECTOR_TYPE_QWEN25VL: @@ -3306,6 +3358,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_PHI4: case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: + case PROJECTOR_TYPE_DOTS_OCR: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_MLP_NORM: return ctx->model.mm_3_b->ne[0]; diff --git a/tools/mtmd/models/dotsocr.cpp b/tools/mtmd/models/dotsocr.cpp new file mode 100644 index 00000000000..9789dba29be --- /dev/null +++ b/tools/mtmd/models/dotsocr.cpp @@ -0,0 +1,49 @@ +#include "models.h" + +ggml_cgraph * clip_graph_dotsocr::build() { + const int n_pos = n_patches; + const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position + + // note: similar to PaddleOCR + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + return ggml_rope_multi( + ctx0, cur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, + 32768, 10000, 1, 0, 1, 32, 1); + }; + + ggml_tensor * inp = build_inp(); + ggml_tensor * cur = build_vit( + inp, n_patches, + NORM_TYPE_NORMAL, + hparams.ffn_op, + nullptr, + add_pos); + + cb(cur, "vit_out", -1); + + // dots.ocr patch merger + projector + { + GGML_ASSERT(hparams.n_merge > 0); + cur = build_norm(cur, model.mm_input_norm_w, model.mm_input_norm_b, NORM_TYPE_NORMAL, 1e-6, -1); + cur = build_stack(cur, hparams.n_merge * hparams.n_merge, n_embd); + cb(cur, "after_patch_merger", -1); + cur = build_ffn(cur, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, // no gate + model.mm_2_w, model.mm_2_b, + FFN_GELU, -1); + cb(cur, "after_projector", -1); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 992eda04bbd..53a87419839 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -68,6 +68,11 @@ struct clip_graph_paddleocr : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_dotsocr : clip_graph { + clip_graph_dotsocr(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + struct clip_graph_cogvlm : clip_graph { clip_graph_cogvlm(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 35b4396fd87..2be0cf67cc9 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -359,6 +359,13 @@ struct mtmd_context { img_end = "<|im_end|>"; image_preproc = std::make_unique(ctx_v); } break; + case PROJECTOR_TYPE_DOTS_OCR: + { + // <|img|> ... (image embeddings) ... <|endofimg|> + img_beg = "<|img|>"; + img_end = "<|endofimg|>"; + image_preproc = std::make_unique(ctx_v); + } break; case PROJECTOR_TYPE_NEMOTRON_V2_VL: { image_preproc = std::make_unique(ctx_v); From d2befc0c4688a991734d10f9fa9d106521b68eea Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 5 Apr 2026 20:38:40 +0200 Subject: [PATCH 5/7] corrections --- tools/mtmd/clip.cpp | 28 +++++++++++----------------- tools/mtmd/models/dotsocr.cpp | 6 +++--- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 11c9f9937b8..cdf0e57fbfc 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1968,6 +1968,8 @@ struct clip_model_loader { model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B); + // post_trunk_norm: applied after all ViT blocks, before the merger + model.post_ln_w = get_tensor(string_format(TN_MM_POST_NORM, "weight")); } break; case PROJECTOR_TYPE_ULTRAVOX: { @@ -3016,29 +3018,21 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } break; case PROJECTOR_TYPE_DOTS_OCR: { - const int merge_ratio = hparams.n_merge; const int pw = image_size_width / patch_size; const int ph = image_size_height / patch_size; - - // For dots.ocr we need [total_patches, 2] -> flattened as (h_pos, w_pos) pairs const int n_pos = ph * pw; std::vector positions(n_pos * 4); int ptr = 0; - // 4 nested loops like GLM-4V, but emitting (y, x) pairs instead of duplicating - for (int y = 0; y < ph; y += merge_ratio) { - for (int x = 0; x < pw; x += merge_ratio) { - for (int dy = 0; dy < merge_ratio; dy++) { - for (int dx = 0; dx < merge_ratio; dx++) { - const int ypos = y + dy; - const int xpos = x + dx; - positions[ptr * 2 + 0] = ypos; // height position - positions[ptr * 2 + 1] = xpos; // width position - positions[ptr * 2 + 2] = ypos; - positions[ptr * 2 + 3] = xpos; - ptr++; - } - } + // flat layout: [h, w, h, w] for each patch + // patches are in raster order (matching conv2d output) + for (int y = 0; y < ph; y++) { + for (int x = 0; x < pw; x++) { + positions[ ptr] = y; + positions[ n_pos + ptr] = x; + positions[2*n_pos + ptr] = y; + positions[3*n_pos + ptr] = x; + ptr++; } } diff --git a/tools/mtmd/models/dotsocr.cpp b/tools/mtmd/models/dotsocr.cpp index 9789dba29be..92974bb670d 100644 --- a/tools/mtmd/models/dotsocr.cpp +++ b/tools/mtmd/models/dotsocr.cpp @@ -21,7 +21,7 @@ ggml_cgraph * clip_graph_dotsocr::build() { ggml_tensor * inp = build_inp(); ggml_tensor * cur = build_vit( inp, n_patches, - NORM_TYPE_NORMAL, + NORM_TYPE_RMS, hparams.ffn_op, nullptr, add_pos); @@ -32,13 +32,13 @@ ggml_cgraph * clip_graph_dotsocr::build() { { GGML_ASSERT(hparams.n_merge > 0); cur = build_norm(cur, model.mm_input_norm_w, model.mm_input_norm_b, NORM_TYPE_NORMAL, 1e-6, -1); - cur = build_stack(cur, hparams.n_merge * hparams.n_merge, n_embd); + cur = build_patch_merge_permute(cur, hparams.n_merge); cb(cur, "after_patch_merger", -1); cur = build_ffn(cur, model.mm_0_w, model.mm_0_b, nullptr, nullptr, // no gate model.mm_2_w, model.mm_2_b, - FFN_GELU, -1); + FFN_GELU_ERF, -1); // nn.GELU() defaults to exact erf-based GELU cb(cur, "after_projector", -1); } From ba692578547a61a296d9d9f41b8e890ff240e53c Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 5 Apr 2026 20:48:55 +0200 Subject: [PATCH 6/7] update docs --- docs/multimodal.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/multimodal.md b/docs/multimodal.md index f2fc1510cfe..29b45b060d2 100644 --- a/docs/multimodal.md +++ b/docs/multimodal.md @@ -37,6 +37,7 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload > - PaddleOCR-VL: https://github.com/ggml-org/llama.cpp/pull/18825 > - GLM-OCR: https://github.com/ggml-org/llama.cpp/pull/19677 > - Deepseek-OCR: https://github.com/ggml-org/llama.cpp/pull/17400 +> - Dots.OCR: https://github.com/ggml-org/llama.cpp/pull/17575 ## Pre-quantized models From 4d6b04dfe28198116201154ac9b70589358a44f8 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 5 Apr 2026 20:54:47 +0200 Subject: [PATCH 7/7] add gguf to test script --- tools/mtmd/tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index e081bde8750..6b751e04198 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -89,6 +89,7 @@ add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0" add_test_vision "ggml-org/granite-docling-258M-GGUF:Q8_0" add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0" add_test_vision "ggml-org/DeepSeek-OCR-GGUF:Q8_0" -p "Free OCR." --chat-template deepseek-ocr +add_test_vision "ggml-org/dots.ocr-GGUF:Q8_0" -p "OCR" add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"