2 changes: 1 addition & 1 deletion colpali_engine/models/__init__.py
@@ -4,6 +4,6 @@
from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor
from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor
from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor
from .qwen3 import BiQwen3, BiQwen3Processor, ColQwen3, ColQwen3Processor
from .qwen3 import BiQwen3, BiQwen3Processor, ColQwen3, ColQwen3Config, ColQwen3Processor
from .qwen3_5 import BiQwen3_5, BiQwen3_5Processor, ColQwen3_5, ColQwen3_5Processor
from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor
2 changes: 1 addition & 1 deletion colpali_engine/models/qwen3/__init__.py
@@ -1,2 +1,2 @@
from .biqwen3 import BiQwen3, BiQwen3Processor
from .colqwen3 import ColQwen3, ColQwen3Processor
from .colqwen3 import ColQwen3, ColQwen3Config, ColQwen3Processor
1 change: 1 addition & 0 deletions colpali_engine/models/qwen3/colqwen3/__init__.py
@@ -1,2 +1,3 @@
from .configuration_colqwen3 import ColQwen3Config
from .modeling_colqwen3 import ColQwen3
from .processing_colqwen3 import ColQwen3Processor
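For context, a minimal sketch (not part of this diff) showing that, after these `__init__.py` changes, the new config class is importable from both the package root and the qwen3 sub-package:

```python
# Sketch assuming the package is installed from this branch; both import paths
# resolve to the same class object re-exported from configuration_colqwen3.py.
from colpali_engine.models import ColQwen3Config
from colpali_engine.models.qwen3 import ColQwen3Config as ColQwen3ConfigFromSubpackage

assert ColQwen3Config is ColQwen3ConfigFromSubpackage
assert ColQwen3Config.model_type == "colqwen3"
```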
59 changes: 59 additions & 0 deletions colpali_engine/models/qwen3/colqwen3/configuration_colqwen3.py
@@ -0,0 +1,59 @@
from copy import deepcopy
from typing import Any

from transformers.configuration_utils import PretrainedConfig
from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig


class ColQwen3Config(PretrainedConfig):
model_type = "colqwen3"
sub_configs: dict[str, Any] = {
"vision_config": Qwen3VLVisionConfig,
"text_config": Qwen3VLTextConfig,
}
is_composition = True

def __init__(
self,
vision_config: PretrainedConfig | dict[str, Any] | None = None,
text_config: PretrainedConfig | dict[str, Any] | None = None,
embed_dim: int = 320,
padding_side: str = "left",
initializer_range: float = 0.02,
dtype: str | None = None,
**kwargs,
):
if vision_config is None:
vision_config = Qwen3VLVisionConfig()
elif isinstance(vision_config, dict):
vision_config = Qwen3VLVisionConfig(**deepcopy(vision_config))

if text_config is None:
text_config = Qwen3VLTextConfig()
elif isinstance(text_config, dict):
text_config = Qwen3VLTextConfig(**deepcopy(text_config))

super().__init__(**kwargs)
self.vision_config = vision_config
self.text_config = text_config
self.embed_dim = embed_dim
self.padding_side = padding_side
self.initializer_range = initializer_range
self.dtype = dtype or getattr(self, "dtype", None)

def to_backbone_config(self) -> Qwen3VLConfig:
config = Qwen3VLConfig(
text_config=self.text_config.to_dict(),
vision_config=self.vision_config.to_dict(),
image_token_id=getattr(self, "image_token_id", 151655),
video_token_id=getattr(self, "video_token_id", 151656),
vision_start_token_id=getattr(self, "vision_start_token_id", 151652),
vision_end_token_id=getattr(self, "vision_end_token_id", 151653),
tie_word_embeddings=getattr(self, "tie_word_embeddings", False),
)

for attr in ("dtype", "use_cache", "pad_token_id", "bos_token_id", "eos_token_id"):
if hasattr(self, attr):
setattr(config, attr, getattr(self, attr))

return config
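A short usage sketch (not part of the diff) for the new `ColQwen3Config`, using the same constructor arguments exercised by the test added further below; the tiny sub-config sizes are illustrative only:

```python
from transformers.models.qwen3_vl import Qwen3VLConfig

from colpali_engine.models.qwen3.colqwen3 import ColQwen3Config

# Build a small config from plain dicts, then convert it to the backbone config
# that the underlying Qwen3VLModel expects.
config = ColQwen3Config(
    text_config={"hidden_size": 64, "num_hidden_layers": 3},
    vision_config={"hidden_size": 32, "depth": 2},
    embed_dim=24,
)
backbone_config = config.to_backbone_config()

assert isinstance(backbone_config, Qwen3VLConfig)
assert backbone_config.text_config.hidden_size == 64
assert config.embed_dim == 24
```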
32 changes: 26 additions & 6 deletions colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py
@@ -6,6 +6,8 @@
from transformers.core_model_loading import WeightRenaming
from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel

from .configuration_colqwen3 import ColQwen3Config


class ColQwen3(Qwen3VLModel):
"""
@@ -20,34 +22,40 @@ class ColQwen3(Qwen3VLModel):
"""

main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
config_class = ColQwen3Config
_keys_to_ignore_on_load_unexpected = [r"^vlm\.lm_head\.weight$", r"^lm_head\.weight$"]
_checkpoint_conversion_mapping = {
r"^base_model\.model\.custom_text_proj": "custom_text_proj",
r"^vlm\.model\.visual": "visual",
r"^vlm\.model\.language_model": "language_model",
r"^embedding_proj_layer": "custom_text_proj",
r"^model\.visual": "visual",
r"^model\.language_model": "language_model",
r"^model\.": "",
}

def __init__(
self,
config: Qwen3VLConfig,
config: Qwen3VLConfig | ColQwen3Config,
mask_non_image_embeddings: bool = False,
**kwargs,
):
dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None))
attn_impl = kwargs.pop("attn_implementation", None)
use_cache = kwargs.pop("use_cache", None)
embed_dim = getattr(config, "embed_dim", 320)
padding_side = getattr(config, "padding_side", "left")
config = self._to_backbone_config(config)

super().__init__(config=config)

hidden_size = getattr(self.config, "hidden_size", None)
if hidden_size is None and hasattr(self.config, "text_config"):
hidden_size = self.config.text_config.hidden_size
hidden_size = self._get_text_hidden_size(self.config)
if hidden_size is None:
raise ValueError("Unable to determine text hidden size for Qwen3VLConfig.")

self.dim = 320
self.dim = embed_dim
self.custom_text_proj = nn.Linear(hidden_size, self.dim)
self.padding_side = "left"
self.padding_side = padding_side
self.mask_non_image_embeddings = mask_non_image_embeddings
self.post_init()

@@ -66,6 +74,18 @@ def from_pretrained(cls, *args, **kwargs):
key_mapping.update(getattr(cls, "_checkpoint_conversion_mapping", {}))
return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)

@staticmethod
def _to_backbone_config(config: Qwen3VLConfig | ColQwen3Config) -> Qwen3VLConfig:
if isinstance(config, ColQwen3Config):
return config.to_backbone_config()
return config

@staticmethod
def _get_text_hidden_size(config: Qwen3VLConfig) -> int | None:
if hasattr(config, "text_config") and hasattr(config.text_config, "hidden_size"):
return config.text_config.hidden_size
return getattr(config, "hidden_size", None)

def forward(self, *args, **kwargs) -> torch.Tensor:
# Handle the custom "pixel_values" input obtained with `ColQwen3Processor` through unpadding
if "pixel_values" in kwargs:
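A brief sketch (not part of the diff) of the two new static helpers, following the behaviour pinned down by the tests below:

```python
from transformers.models.qwen3_vl import Qwen3VLConfig

from colpali_engine.models.qwen3.colqwen3 import ColQwen3, ColQwen3Config

col_config = ColQwen3Config(text_config={"hidden_size": 64}, vision_config={"hidden_size": 32})

# A ColQwen3Config is converted to a plain Qwen3VLConfig before the backbone is built...
backbone_config = ColQwen3._to_backbone_config(col_config)
assert isinstance(backbone_config, Qwen3VLConfig)

# ...and the projection layer size is read from the text sub-config, falling back
# to a top-level `hidden_size` only when no text config is present.
assert ColQwen3._get_text_hidden_size(backbone_config) == 64
```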
37 changes: 37 additions & 0 deletions tests/models/qwen3/colqwen3/test_configuration_colqwen3.py
@@ -0,0 +1,37 @@
from transformers.models.qwen3_vl import Qwen3VLConfig

from colpali_engine.models.qwen3.colqwen3 import ColQwen3, ColQwen3Config


def test_colqwen3_config_converts_to_qwen3_vl_config():
config = ColQwen3Config(
text_config={"hidden_size": 64, "num_hidden_layers": 3},
vision_config={"hidden_size": 32, "depth": 2},
embed_dim=24,
image_token_id=11,
video_token_id=12,
vision_start_token_id=13,
vision_end_token_id=14,
)

backbone_config = config.to_backbone_config()

assert isinstance(backbone_config, Qwen3VLConfig)
assert backbone_config.text_config.hidden_size == 64
assert backbone_config.text_config.num_hidden_layers == 3
assert backbone_config.vision_config.hidden_size == 32
assert backbone_config.vision_config.depth == 2
assert backbone_config.image_token_id == 11
assert backbone_config.video_token_id == 12
assert backbone_config.vision_start_token_id == 13
assert backbone_config.vision_end_token_id == 14


def test_colqwen3_prefers_text_hidden_size_over_top_level_hidden_size():
config = Qwen3VLConfig(
text_config={"hidden_size": 64},
vision_config={"hidden_size": 32},
)
config.hidden_size = 32

assert ColQwen3._get_text_hidden_size(config) == 64
47 changes: 37 additions & 10 deletions tests/models/test_checkpoint_key_mappings.py
@@ -126,25 +126,52 @@ def test_colqwen2_5_omni_conversion_mapping_is_registered_for_adapter_loading():
assert key == "custom_text_proj.lora_B.default.weight"


def test_colqwen3_adapter_key_mapping_remaps_custom_text_proj_and_layers():
def test_colqwen3_checkpoint_key_mapping_remaps_adapter_and_tomoro_checkpoint_keys():
assert (
_apply_mapping(
"base_model.model.custom_text_proj.lora_A.default.weight",
ColQwen3._checkpoint_conversion_mapping,
)
== "custom_text_proj.lora_A.default.weight"
)
assert (
_apply_mapping(
"vlm.model.language_model.layers.17.self_attn.v_proj.weight",
ColQwen3._checkpoint_conversion_mapping,
)
== "language_model.layers.17.self_attn.v_proj.weight"
)
assert (
_apply_mapping(
"vlm.model.visual.blocks.3.attn.proj.weight",
ColQwen3._checkpoint_conversion_mapping,
)
== "visual.blocks.3.attn.proj.weight"
)
assert (
_apply_mapping("embedding_proj_layer.bias", ColQwen3._checkpoint_conversion_mapping)
== "custom_text_proj.bias"
)


def test_colqwen3_ignores_expected_unexpected_lm_head_weight():
assert r"^vlm\.lm_head\.weight$" in ColQwen3._keys_to_ignore_on_load_unexpected
assert r"^lm_head\.weight$" in ColQwen3._keys_to_ignore_on_load_unexpected


def test_colqwen3_conversion_mapping_is_registered_for_adapter_loading():
def test_colqwen3_conversion_mapping_is_registered_for_adapter_and_tomoro_checkpoint_loading():
mapping = get_checkpoint_conversion_mapping("qwen3_vl")
assert mapping is not None

key = "base_model.model.custom_text_proj.lora_B.default.weight"
for renaming in mapping:
if not hasattr(renaming, "source_patterns") or not hasattr(renaming, "target_patterns"):
continue
for pattern, replacement in zip(renaming.source_patterns, renaming.target_patterns):
key = re.sub(pattern, replacement, key)

assert key == "custom_text_proj.lora_B.default.weight"
for key, expected in (
("base_model.model.custom_text_proj.lora_B.default.weight", "custom_text_proj.lora_B.default.weight"),
("vlm.model.language_model.layers.3.mlp.down_proj.weight", "language_model.layers.3.mlp.down_proj.weight"),
):
remapped_key = key
for renaming in mapping:
if not hasattr(renaming, "source_patterns") or not hasattr(renaming, "target_patterns"):
continue
for pattern, replacement in zip(renaming.source_patterns, renaming.target_patterns):
remapped_key = re.sub(pattern, replacement, remapped_key)

assert remapped_key == expected
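For reference, the `_apply_mapping` helper used above is defined earlier in this test file and is not shown in the diff; a plausible minimal implementation, assuming the mapping is a dict of regex pattern to replacement as in `_checkpoint_conversion_mapping`, would be:

```python
import re


def _apply_mapping(key: str, mapping: dict[str, str]) -> str:
    """Apply every regex -> replacement pair in `mapping` to `key`, in order."""
    for pattern, replacement in mapping.items():
        key = re.sub(pattern, replacement, key)
    return key
```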