From f5d67cb4ef5873b915f5fe776250f98b06375309 Mon Sep 17 00:00:00 2001
From: Huang Xin
Date: Sun, 22 Mar 2026 20:00:11 +0800
Subject: [PATCH 1/2] Fix Tomoro ColQwen3 checkpoint loading

---
 colpali_engine/models/__init__.py             |  2 +-
 colpali_engine/models/qwen3/__init__.py       |  2 +-
 .../models/qwen3/colqwen3/__init__.py         |  1 +
 .../qwen3/colqwen3/configuration_colqwen3.py  | 59 +++++++++++++++++++
 .../qwen3/colqwen3/modeling_colqwen3.py       | 32 ++++++++--
 .../colqwen3/test_configuration_colqwen3.py   | 37 ++++++++++++
 tests/models/test_checkpoint_key_mappings.py  | 44 ++++++++++----
 7 files changed, 159 insertions(+), 18 deletions(-)
 create mode 100644 colpali_engine/models/qwen3/colqwen3/configuration_colqwen3.py
 create mode 100644 tests/models/qwen3/colqwen3/test_configuration_colqwen3.py

diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py
index e4845bf7..2944bd27 100644
--- a/colpali_engine/models/__init__.py
+++ b/colpali_engine/models/__init__.py
@@ -4,6 +4,6 @@
 from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor
 from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor
 from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor
-from .qwen3 import BiQwen3, BiQwen3Processor, ColQwen3, ColQwen3Processor
+from .qwen3 import BiQwen3, BiQwen3Processor, ColQwen3, ColQwen3Config, ColQwen3Processor
 from .qwen3_5 import BiQwen3_5, BiQwen3_5Processor, ColQwen3_5, ColQwen3_5Processor
 from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor

diff --git a/colpali_engine/models/qwen3/__init__.py b/colpali_engine/models/qwen3/__init__.py
index efcee26f..8958c447 100644
--- a/colpali_engine/models/qwen3/__init__.py
+++ b/colpali_engine/models/qwen3/__init__.py
@@ -1,2 +1,2 @@
 from .biqwen3 import BiQwen3, BiQwen3Processor
-from .colqwen3 import ColQwen3, ColQwen3Processor
+from .colqwen3 import ColQwen3, ColQwen3Config, ColQwen3Processor

diff --git a/colpali_engine/models/qwen3/colqwen3/__init__.py b/colpali_engine/models/qwen3/colqwen3/__init__.py
index 6369cb69..7fe3ecab 100644
--- a/colpali_engine/models/qwen3/colqwen3/__init__.py
+++ b/colpali_engine/models/qwen3/colqwen3/__init__.py
@@ -1,2 +1,3 @@
+from .configuration_colqwen3 import ColQwen3Config
 from .modeling_colqwen3 import ColQwen3
 from .processing_colqwen3 import ColQwen3Processor

diff --git a/colpali_engine/models/qwen3/colqwen3/configuration_colqwen3.py b/colpali_engine/models/qwen3/colqwen3/configuration_colqwen3.py
new file mode 100644
index 00000000..3b73c4ae
--- /dev/null
+++ b/colpali_engine/models/qwen3/colqwen3/configuration_colqwen3.py
@@ -0,0 +1,59 @@
+from copy import deepcopy
+from typing import Any
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig
+
+
+class ColQwen3Config(PretrainedConfig):
+    model_type = "colqwen3"
+    sub_configs: dict[str, Any] = {
+        "vision_config": Qwen3VLVisionConfig,
+        "text_config": Qwen3VLTextConfig,
+    }
+    is_composition = True
+
+    def __init__(
+        self,
+        vision_config: PretrainedConfig | dict[str, Any] | None = None,
+        text_config: PretrainedConfig | dict[str, Any] | None = None,
+        embed_dim: int = 320,
+        padding_side: str = "left",
+        initializer_range: float = 0.02,
+        dtype: str | None = None,
+        **kwargs,
+    ):
+        if vision_config is None:
+            vision_config = Qwen3VLVisionConfig()
+        elif isinstance(vision_config, dict):
+            vision_config = Qwen3VLVisionConfig(**deepcopy(vision_config))
+
+        if text_config is None:
+            text_config = Qwen3VLTextConfig()
+        elif isinstance(text_config, dict):
+            text_config = Qwen3VLTextConfig(**deepcopy(text_config))
+
+        super().__init__(**kwargs)
+        self.vision_config = vision_config
+        self.text_config = text_config
+        self.embed_dim = embed_dim
+        self.padding_side = padding_side
+        self.initializer_range = initializer_range
+        self.dtype = dtype or getattr(self, "dtype", None)
+
+    def to_backbone_config(self) -> Qwen3VLConfig:
+        config = Qwen3VLConfig(
+            text_config=self.text_config.to_dict(),
+            vision_config=self.vision_config.to_dict(),
+            image_token_id=getattr(self, "image_token_id", 151655),
+            video_token_id=getattr(self, "video_token_id", 151656),
+            vision_start_token_id=getattr(self, "vision_start_token_id", 151652),
+            vision_end_token_id=getattr(self, "vision_end_token_id", 151653),
+            tie_word_embeddings=getattr(self, "tie_word_embeddings", False),
+        )
+
+        for attr in ("dtype", "use_cache", "pad_token_id", "bos_token_id", "eos_token_id"):
+            if hasattr(self, attr):
+                setattr(config, attr, getattr(self, attr))
+
+        return config
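
Review note on the new file: ColQwen3Config keeps the retrieval-specific fields
(embed_dim, padding_side) next to the two Qwen3-VL sub-configs, and
to_backbone_config() lowers it to a plain Qwen3VLConfig, falling back to the
usual Qwen3-VL special-token ids when the checkpoint does not override them.
A minimal sketch of the intended round trip, mirroring the new unit test
further down (tiny illustrative sizes, not a real checkpoint shape):

    from colpali_engine.models import ColQwen3Config

    config = ColQwen3Config(
        text_config={"hidden_size": 64, "num_hidden_layers": 3},
        vision_config={"hidden_size": 32, "depth": 2},
        embed_dim=24,  # output width of custom_text_proj
    )
    backbone_config = config.to_backbone_config()  # plain Qwen3VLConfig
    assert backbone_config.text_config.hidden_size == 64
    assert backbone_config.image_token_id == 151655  # default when unset
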
diff --git a/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py b/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py
index b11357eb..9e224e9c 100644
--- a/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py
+++ b/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py
@@ -6,6 +6,8 @@
 from transformers.core_model_loading import WeightRenaming
 from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel
 
+from .configuration_colqwen3 import ColQwen3Config
+
 
 class ColQwen3(Qwen3VLModel):
     """
@@ -20,8 +22,13 @@ class ColQwen3(Qwen3VLModel):
     """
 
     main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related
+    config_class = ColQwen3Config
+    _keys_to_ignore_on_load_unexpected = [r"^vlm\.lm_head\.weight$", r"^lm_head\.weight$"]
     _checkpoint_conversion_mapping = {
         r"^base_model\.model\.custom_text_proj": "custom_text_proj",
+        r"^vlm\.model\.visual": "visual",
+        r"^vlm\.model\.language_model": "language_model",
+        r"^embedding_proj_layer": "custom_text_proj",
         r"^model\.visual": "visual",
         r"^model\.language_model": "language_model",
         r"^model\.": "",
@@ -29,25 +36,26 @@ class ColQwen3(Qwen3VLModel):
 
     def __init__(
         self,
-        config: Qwen3VLConfig,
+        config: Qwen3VLConfig | ColQwen3Config,
         mask_non_image_embeddings: bool = False,
         **kwargs,
    ):
         dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None))
         attn_impl = kwargs.pop("attn_implementation", None)
         use_cache = kwargs.pop("use_cache", None)
+        embed_dim = getattr(config, "embed_dim", 320)
+        padding_side = getattr(config, "padding_side", "left")
+        config = self._to_backbone_config(config)
         super().__init__(config=config)
-        hidden_size = getattr(self.config, "hidden_size", None)
-        if hidden_size is None and hasattr(self.config, "text_config"):
-            hidden_size = self.config.text_config.hidden_size
+        hidden_size = self._get_text_hidden_size(self.config)
         if hidden_size is None:
             raise ValueError("Unable to determine text hidden size for Qwen3VLConfig.")
-        self.dim = 320
+        self.dim = embed_dim
         self.custom_text_proj = nn.Linear(hidden_size, self.dim)
-        self.padding_side = "left"
+        self.padding_side = padding_side
         self.mask_non_image_embeddings = mask_non_image_embeddings
         self.post_init()
@@ -66,6 +74,18 @@ def from_pretrained(cls, *args, **kwargs):
         key_mapping.update(getattr(cls, "_checkpoint_conversion_mapping", {}))
         return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
 
+    @staticmethod
+    def _to_backbone_config(config: Qwen3VLConfig | ColQwen3Config) -> Qwen3VLConfig:
+        if isinstance(config, ColQwen3Config):
+            return config.to_backbone_config()
+        return config
+
+    @staticmethod
+    def _get_text_hidden_size(config: Qwen3VLConfig) -> int | None:
+        if hasattr(config, "text_config") and hasattr(config.text_config, "hidden_size"):
+            return config.text_config.hidden_size
+        return getattr(config, "hidden_size", None)
+
     def forward(self, *args, **kwargs) -> torch.Tensor:
         # Handle the custom "pixel_values" input obtained with `ColQwen3Processor` through unpadding
         if "pixel_values" in kwargs:
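
Review note on the modeling changes: with config_class and the constructor
changes in place, the model can be built directly from a ColQwen3Config.
__init__ reads embed_dim and padding_side off the wrapper, lowers it to the
backbone Qwen3VLConfig via _to_backbone_config, and sizes custom_text_proj
from the text hidden size. A sketch reusing the tiny config from the earlier
note (random weights, illustrative sizes only):

    from colpali_engine.models import ColQwen3, ColQwen3Config

    config = ColQwen3Config(
        text_config={"hidden_size": 64, "num_hidden_layers": 3},
        vision_config={"hidden_size": 32, "depth": 2},
        embed_dim=24,
    )
    model = ColQwen3(config)  # lowered to Qwen3VLConfig internally
    assert model.dim == 24    # taken from embed_dim, not the old hard-coded 320
    assert model.custom_text_proj.out_features == 24

Keys matching _keys_to_ignore_on_load_unexpected (the checkpoint's
lm_head.weight, with or without the vlm. prefix) are silenced at load time
rather than reported as unexpected, since the retrieval model carries no
language-model head.
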
diff --git a/tests/models/qwen3/colqwen3/test_configuration_colqwen3.py b/tests/models/qwen3/colqwen3/test_configuration_colqwen3.py
new file mode 100644
index 00000000..66951282
--- /dev/null
+++ b/tests/models/qwen3/colqwen3/test_configuration_colqwen3.py
@@ -0,0 +1,37 @@
+from transformers.models.qwen3_vl import Qwen3VLConfig
+
+from colpali_engine.models.qwen3.colqwen3 import ColQwen3, ColQwen3Config
+
+
+def test_colqwen3_config_converts_to_qwen3_vl_config():
+    config = ColQwen3Config(
+        text_config={"hidden_size": 64, "num_hidden_layers": 3},
+        vision_config={"hidden_size": 32, "depth": 2},
+        embed_dim=24,
+        image_token_id=11,
+        video_token_id=12,
+        vision_start_token_id=13,
+        vision_end_token_id=14,
+    )
+
+    backbone_config = config.to_backbone_config()
+
+    assert isinstance(backbone_config, Qwen3VLConfig)
+    assert backbone_config.text_config.hidden_size == 64
+    assert backbone_config.text_config.num_hidden_layers == 3
+    assert backbone_config.vision_config.hidden_size == 32
+    assert backbone_config.vision_config.depth == 2
+    assert backbone_config.image_token_id == 11
+    assert backbone_config.video_token_id == 12
+    assert backbone_config.vision_start_token_id == 13
+    assert backbone_config.vision_end_token_id == 14
+
+
+def test_colqwen3_prefers_text_hidden_size_over_top_level_hidden_size():
+    config = Qwen3VLConfig(
+        text_config={"hidden_size": 64},
+        vision_config={"hidden_size": 32},
+    )
+    config.hidden_size = 32
+
+    assert ColQwen3._get_text_hidden_size(config) == 64

diff --git a/tests/models/test_checkpoint_key_mappings.py b/tests/models/test_checkpoint_key_mappings.py
index c5cf3194..b1a4eef8 100644
--- a/tests/models/test_checkpoint_key_mappings.py
+++ b/tests/models/test_checkpoint_key_mappings.py
@@ -126,7 +126,7 @@ def test_colqwen2_5_omni_conversion_mapping_is_registered_for_adapter_loading():
     assert key == "custom_text_proj.lora_B.default.weight"
 
 
-def test_colqwen3_adapter_key_mapping_remaps_custom_text_proj_and_layers():
+def test_colqwen3_checkpoint_key_mapping_remaps_adapter_and_tomoro_checkpoint_keys():
     assert (
         _apply_mapping(
             "base_model.model.custom_text_proj.lora_A.default.weight",
@@ -134,17 +134,41 @@
             ColQwen3._checkpoint_conversion_mapping,
         )
         == "custom_text_proj.lora_A.default.weight"
     )
+    assert (
+        _apply_mapping(
+            "vlm.model.language_model.layers.17.self_attn.v_proj.weight",
+            ColQwen3._checkpoint_conversion_mapping,
+        )
+        == "language_model.layers.17.self_attn.v_proj.weight"
+    )
+    assert (
+        _apply_mapping(
+            "vlm.model.visual.blocks.3.attn.proj.weight",
+            ColQwen3._checkpoint_conversion_mapping,
+        )
+        == "visual.blocks.3.attn.proj.weight"
+    )
+    assert _apply_mapping("embedding_proj_layer.bias", ColQwen3._checkpoint_conversion_mapping) == "custom_text_proj.bias"
 
 
-def test_colqwen3_conversion_mapping_is_registered_for_adapter_loading():
+def test_colqwen3_ignores_expected_unexpected_lm_head_weight():
+    assert r"^vlm\.lm_head\.weight$" in ColQwen3._keys_to_ignore_on_load_unexpected
+    assert r"^lm_head\.weight$" in ColQwen3._keys_to_ignore_on_load_unexpected
+
+
+def test_colqwen3_conversion_mapping_is_registered_for_adapter_and_tomoro_checkpoint_loading():
     mapping = get_checkpoint_conversion_mapping("qwen3_vl")
     assert mapping is not None
-    key = "base_model.model.custom_text_proj.lora_B.default.weight"
-    for renaming in mapping:
-        if not hasattr(renaming, "source_patterns") or not hasattr(renaming, "target_patterns"):
-            continue
-        for pattern, replacement in zip(renaming.source_patterns, renaming.target_patterns):
-            key = re.sub(pattern, replacement, key)
-
-    assert key == "custom_text_proj.lora_B.default.weight"
+    for key, expected in (
+        ("base_model.model.custom_text_proj.lora_B.default.weight", "custom_text_proj.lora_B.default.weight"),
+        ("vlm.model.language_model.layers.3.mlp.down_proj.weight", "language_model.layers.3.mlp.down_proj.weight"),
+    ):
+        remapped_key = key
+        for renaming in mapping:
+            if not hasattr(renaming, "source_patterns") or not hasattr(renaming, "target_patterns"):
+                continue
+            for pattern, replacement in zip(renaming.source_patterns, renaming.target_patterns):
+                remapped_key = re.sub(pattern, replacement, remapped_key)
+
+        assert remapped_key == expected
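
Note on the tests in this patch: the _apply_mapping helper used throughout
these assertions already exists near the top of
tests/models/test_checkpoint_key_mappings.py and is not part of this diff.
Presumably it applies each regex/replacement pair of a conversion mapping in
insertion order, along these lines (a sketch, not the repository's actual
helper):

    import re

    def _apply_mapping(key: str, mapping: dict[str, str]) -> str:
        # Rewrite one checkpoint key by applying every regex -> replacement
        # pair of the conversion mapping, in dict insertion order.
        for pattern, replacement in mapping.items():
            key = re.sub(pattern, replacement, key)
        return key

Under that reading, a Tomoro-style key such as
"vlm.model.visual.blocks.3.attn.proj.weight" is rewritten by
^vlm\.model\.visual to "visual.blocks.3.attn.proj.weight", and
"embedding_proj_layer.bias" lands on "custom_text_proj.bias", which is
exactly what the new assertions check.
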
r"^vlm\.lm_head\.weight$" in ColQwen3._keys_to_ignore_on_load_unexpected + assert r"^lm_head\.weight$" in ColQwen3._keys_to_ignore_on_load_unexpected + + +def test_colqwen3_conversion_mapping_is_registered_for_adapter_and_tomoro_checkpoint_loading(): mapping = get_checkpoint_conversion_mapping("qwen3_vl") assert mapping is not None - key = "base_model.model.custom_text_proj.lora_B.default.weight" - for renaming in mapping: - if not hasattr(renaming, "source_patterns") or not hasattr(renaming, "target_patterns"): - continue - for pattern, replacement in zip(renaming.source_patterns, renaming.target_patterns): - key = re.sub(pattern, replacement, key) - - assert key == "custom_text_proj.lora_B.default.weight" + for key, expected in ( + ("base_model.model.custom_text_proj.lora_B.default.weight", "custom_text_proj.lora_B.default.weight"), + ("vlm.model.language_model.layers.3.mlp.down_proj.weight", "language_model.layers.3.mlp.down_proj.weight"), + ): + remapped_key = key + for renaming in mapping: + if not hasattr(renaming, "source_patterns") or not hasattr(renaming, "target_patterns"): + continue + for pattern, replacement in zip(renaming.source_patterns, renaming.target_patterns): + remapped_key = re.sub(pattern, replacement, remapped_key) + + assert remapped_key == expected From bd72eed2b104af86248283a95963faccf5c69bee Mon Sep 17 00:00:00 2001 From: Huang Xin Date: Sun, 22 Mar 2026 21:01:46 +0800 Subject: [PATCH 2/2] Fix Ruff line length in key mapping test --- tests/models/test_checkpoint_key_mappings.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/models/test_checkpoint_key_mappings.py b/tests/models/test_checkpoint_key_mappings.py index b1a4eef8..282984f1 100644 --- a/tests/models/test_checkpoint_key_mappings.py +++ b/tests/models/test_checkpoint_key_mappings.py @@ -148,7 +148,10 @@ def test_colqwen3_checkpoint_key_mapping_remaps_adapter_and_tomoro_checkpoint_ke ) == "visual.blocks.3.attn.proj.weight" ) - assert _apply_mapping("embedding_proj_layer.bias", ColQwen3._checkpoint_conversion_mapping) == "custom_text_proj.bias" + assert ( + _apply_mapping("embedding_proj_layer.bias", ColQwen3._checkpoint_conversion_mapping) + == "custom_text_proj.bias" + ) def test_colqwen3_ignores_expected_unexpected_lm_head_weight():