20 changes: 12 additions & 8 deletions src/transformers/conversion_mapping.py
@@ -108,9 +108,6 @@ def _build_checkpoint_conversion_mapping():
             WeightRenaming(source_patterns=r"vlm.model", target_patterns="vlm"),
             WeightRenaming(source_patterns=r"vlm(?!\.(language_model|visual))", target_patterns="vlm.language_model"),
         ],
-        "gemma3n_text": [
-            WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"),
-        ],
         "timm_wrapper": [
             # Simply add the prefix `timm_model`. Similar to `base_model_prefix` but also removes prefix
             # when saving. TODO: Would be probably much cleaner with a `add_prefix` argument in WeightRenaming
@@ -152,9 +149,6 @@ def _build_checkpoint_conversion_mapping():
             WeightRenaming("attention_layer_norm", "input_layernorm"),
             WeightRenaming("feedforward_layer_norm", "post_attention_layernorm"),
         ],
-        "qwen3_5_text": [
-            WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"),
-        ],
         "sam3_tracker": [
             WeightRenaming(
                 source_patterns=r"detector_model.vision_encoder.backbone.", target_patterns="vision_encoder.backbone."
@@ -518,8 +512,7 @@ def _build_checkpoint_conversion_mapping():
         ),
     ]

-    mapping["qwen3_5_moe_text"] = mapping["qwen3_5_text"].copy()
-    mapping["qwen3_5_moe_text"] += mapping["qwen2_moe"].copy()
+    mapping["qwen3_5_moe_text"] = mapping["qwen2_moe"].copy()

     mapping["cohere_asr"] = [
         WeightRenaming(r"encoder\.pre_encode\.conv\.", r"encoder.subsampling.layers."),
@@ -612,6 +605,17 @@ def get_model_conversion_mapping(
     # Load models with explicit, user-provided key mapping
     if key_mapping is not None:
         weight_conversions = [WeightRenaming(source_patterns=k, target_patterns=v) for k, v in key_mapping.items()]
+    elif any(
+        allowed_name in class_name.__name__.lower()
+        for class_name in model.__class__.__mro__[:-1]
+        for allowed_name in ["qwen3_5", "gemma3n"]
+    ):
+        # TODO: these are used only for VLMs which sometimes are loaded as LLMs
+        # prob can be fixed as we did with `config_class`, all at once for VLM-LLMs
+        weight_conversions = [
+            WeightRenaming(source_patterns=k, target_patterns=v)
+            for k, v in model._checkpoint_conversion_mapping.items()
+        ]

     # Model have several `PreTrainedModel` within with the same model type
     # For ex: XForConditionalGeneration -> XModel. We don't want to apply the same
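Taken together, these hunks drop the model-type-level `gemma3n_text` / `qwen3_5_text` entries and instead fall back to each class's own `_checkpoint_conversion_mapping`. For orientation, here is a minimal sketch of the rename that mapping performs, using plain `re.sub` rather than the library's `WeightRenaming` machinery; the checkpoint keys below are hypothetical:

```python
import re

# Per-class mapping as added in this PR: keys saved by the multimodal
# ConditionalGeneration model are renamed to match the text-only CausalLM.
checkpoint_conversion_mapping = {"model.language_model": "model"}

# Hypothetical keys from a saved multimodal (ConditionalGeneration) checkpoint.
state_dict_keys = [
    "model.language_model.embed_tokens.weight",
    "model.language_model.layers.0.self_attn.q_proj.weight",
    "lm_head.weight",
]

renamed = []
for key in state_dict_keys:
    # Apply each source -> target pattern; keys that don't match pass through.
    for source, target in checkpoint_conversion_mapping.items():
        key = re.sub(source, target, key)
    renamed.append(key)

print(renamed)
# ['model.embed_tokens.weight', 'model.layers.0.self_attn.q_proj.weight', 'lm_head.weight']
```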
1 change: 1 addition & 0 deletions src/transformers/models/gemma3n/modeling_gemma3n.py
@@ -1772,6 +1772,7 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
     _tp_plan = {"lm_head": "colwise_gather_output"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
     config: Gemma3nTextConfig
+    _checkpoint_conversion_mapping = {"model.language_model": "model"}

     def __init__(self, config: Gemma3nTextConfig):
         super().__init__(config)
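With the class attribute in place, the text-only model can load a checkpoint saved by the full multimodal model, whose text weights live under `model.language_model.*`. A usage sketch, assuming the public `google/gemma-3n-E2B-it` checkpoint id:

```python
from transformers import Gemma3nForCausalLM

# The hub checkpoint was saved from the multimodal model; the class-level
# mapping renames its `model.language_model.*` keys to `model.*` at load time.
model = Gemma3nForCausalLM.from_pretrained("google/gemma-3n-E2B-it")
```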
2 changes: 1 addition & 1 deletion src/transformers/models/gemma3n/modular_gemma3n.py
@@ -1926,7 +1926,7 @@ def forward(

 @auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
 class Gemma3nForCausalLM(Gemma3ForCausalLM):
-    pass
+    _checkpoint_conversion_mapping = {"model.language_model": "model"}


 class Gemma3nMultimodalEmbedder(nn.Module):
1 change: 1 addition & 0 deletions src/transformers/models/qwen3_5/modeling_qwen3_5.py
@@ -1690,6 +1690,7 @@ class Qwen3_5ForCausalLM(Qwen3_5PreTrainedModel, GenerationMixin):
     _tp_plan = {"lm_head": "colwise_gather_output"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
     config: Qwen3_5TextConfig
+    _checkpoint_conversion_mapping = {"model.language_model": "model"}
     _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"]

     def __init__(self, config):
1 change: 1 addition & 0 deletions src/transformers/models/qwen3_5/modular_qwen3_5.py
@@ -652,6 +652,7 @@ def forward(

 class Qwen3_5ForCausalLM(Qwen3ForCausalLM):
     config: Qwen3_5TextConfig
+    _checkpoint_conversion_mapping = {"model.language_model": "model"}
     _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"]

     def __init__(self, config):
1 change: 1 addition & 0 deletions src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
@@ -1897,6 +1897,7 @@ class Qwen3_5MoeForCausalLM(Qwen3_5MoePreTrainedModel, GenerationMixin):
     _tp_plan = {"lm_head": "colwise_gather_output"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
     config: Qwen3_5MoeTextConfig
+    _checkpoint_conversion_mapping = {"model.language_model": "model"}
     _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"]

     def __init__(self, config):
1 change: 1 addition & 0 deletions src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py
@@ -240,6 +240,7 @@ class Qwen3_5MoeModel(Qwen3_5Model):

 class Qwen3_5MoeForCausalLM(Qwen3NextForCausalLM):
     config: Qwen3_5MoeTextConfig
+    _checkpoint_conversion_mapping = {"model.language_model": "model"}
     _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"]

     def __init__(self, config):
6 changes: 0 additions & 6 deletions tests/models/gemma3n/test_modeling_gemma3n.py
@@ -866,12 +866,6 @@ def test_get_audio_features_attentions(self, return_dict: bool | None):
     def test_generate_with_quant_cache(self):
         pass

-    @unittest.skip(
-        "Conversion only for the `CausalLM` loading from saved `ConditionalLM`, doesn't apply to simple VLM"
-    )
-    def test_reverse_loading_mapping(self, check_keys_were_modified=True):
-        pass
-
     def _check_hidden_states_for_generate(
         self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
     ):
6 changes: 0 additions & 6 deletions tests/models/qwen3_5/test_modeling_qwen3_5.py
@@ -304,12 +304,6 @@ def setUp(self):
     def test_config(self):
         self.config_tester.run_common_tests()

-    @unittest.skip(
-        "Conversion only for the `CausalLM` loading from saved `ConditionalLM`, doesn't apply to simple VLM"
-    )
-    def test_reverse_loading_mapping(self, check_keys_were_modified=True):
-        pass
-
     def _get_conv_state_shape(self, batch_size: int, config):
         num_v_heads = config.linear_num_value_heads
         num_k_heads = config.linear_num_key_heads
6 changes: 0 additions & 6 deletions tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py
@@ -300,12 +300,6 @@ def setUp(self):
     def test_config(self):
         self.config_tester.run_common_tests()

-    @unittest.skip(
-        "Conversion only for the `CausalLM` loading from saved `ConditionalLM`, doesn't apply to simple VLM"
-    )
-    def test_reverse_loading_mapping(self, check_keys_were_modified=True):
-        pass
-
     def _get_conv_state_shape(self, batch_size: int, config):
         num_v_heads = config.linear_num_value_heads
         num_k_heads = config.linear_num_key_heads