From 3e0f392b75f2ce3a46ea1770644ca94b02a3303d Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Thu, 2 Oct 2025 03:32:38 +0800 Subject: [PATCH 1/8] support aux loss in qwen3vlmoe --- .../models/qwen3_vl/modeling_qwen3_vl.py | 8 +- .../models/qwen3_vl/modular_qwen3_vl.py | 6 +- .../configuration_qwen3_vl_moe.py | 4 + .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 193 +++++++++++++++--- .../qwen3_vl_moe/modular_qwen3_vl_moe.py | 128 +++++++++++- 5 files changed, 296 insertions(+), 43 deletions(-) diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index d3bc3b6b044f..2b217dd0db52 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -1105,6 +1105,7 @@ def get_placeholder_mask( @auto_docstring @can_return_tuple + @check_model_inputs def forward( self, input_ids: torch.LongTensor = None, @@ -1235,8 +1236,6 @@ def forward( return Qwen3VLModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, rope_deltas=self.rope_deltas, ) @@ -1313,8 +1312,7 @@ def language_model(self): def visual(self): return self.model.visual - @can_return_tuple - @auto_docstring + @check_model_inputs def forward( self, input_ids: torch.LongTensor = None, @@ -1372,8 +1370,6 @@ def forward( loss=loss, logits=logits, past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, rope_deltas=outputs.rope_deltas, ) diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 7f6a81919980..c6834bd27c58 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -1007,6 +1007,7 @@ def get_video_features( @auto_docstring @can_return_tuple + @check_model_inputs def forward( self, input_ids: torch.LongTensor = None, @@ -1137,8 +1138,6 @@ def forward( return Qwen3VLModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, rope_deltas=self.rope_deltas, ) @@ -1151,6 +1150,7 @@ class Qwen3VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): config: Qwen3VLConfig _checkpoint_conversion_mapping = {} + @check_model_inputs def forward( self, input_ids: torch.LongTensor = None, @@ -1208,8 +1208,6 @@ def forward( loss=loss, logits=logits, past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, rope_deltas=outputs.rope_deltas, ) diff --git a/src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py index c4a31e8f9f92..25358aa79bff 100644 --- a/src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py @@ -80,6 +80,8 @@ class Qwen3VLMoeTextConfig(PretrainedConfig): Number of routed experts. norm_topk_prob (`bool`, *optional*, defaults to `True`): Whether to normalize the topk probabilities. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. mlp_only_layers (`List[int]`, *optional*, defaults to `[]`): Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock The list contains layer index, from 0 to num_layers-1 if we have num_layers layers @@ -178,6 +180,7 @@ def __init__( num_experts_per_tok=4, num_experts=60, norm_topk_prob=True, + router_aux_loss_coef=0.001, mlp_only_layers=None, rope_scaling=None, head_dim=None, @@ -213,6 +216,7 @@ def __init__( self.num_experts_per_tok = num_experts_per_tok self.num_experts = num_experts self.norm_topk_prob = norm_topk_prob + self.router_aux_loss_coef = router_aux_loss_coef self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index e6550eee3590..23d3e2026515 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -971,6 +971,36 @@ def _deepstack_process( return hidden_states +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Qwen3VLMoe causal language model (or autoregressive) outputs. + """ +) +class Qwen3VLMoeCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + aux_loss: Optional[torch.FloatTensor] = None + + @dataclass @auto_docstring( custom_intro=""" @@ -1217,6 +1247,7 @@ def get_placeholder_mask( @auto_docstring @can_return_tuple + @check_model_inputs def forward( self, input_ids: torch.LongTensor = None, @@ -1347,39 +1378,90 @@ def forward( return Qwen3VLMoeModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, rope_deltas=self.rope_deltas, ) -@dataclass -@auto_docstring( - custom_intro=""" - Base class for Qwen3VLMoe causal language model (or autoregressive) outputs. - """ -) -class Qwen3VLMoeCausalLMOutputWithPast(ModelOutput): +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: r""" - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): - The rope index difference between sequence length and multimodal rope. + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - past_key_values: Optional[Cache] = None - hidden_states: Optional[tuple[torch.FloatTensor]] = None - attentions: Optional[tuple[torch.FloatTensor]] = None - rope_deltas: Optional[torch.LongTensor] = None + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin): @@ -1425,8 +1507,7 @@ def language_model(self): def visual(self): return self.model.visual - @can_return_tuple - @auto_docstring + @check_model_inputs def forward( self, input_ids: torch.LongTensor = None, @@ -1454,8 +1535,46 @@ def forward( The temporal, height and width of feature shape of each video in LLM. Example: - TODO: Add example - """ + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration + + >>> model = Qwen3VLMoeForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image in short."}, + ], + } + ] + + >>> # Preparation for inference + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ) + >>> inputs = inputs.to(model.device) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=128) + >>> generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + >>> processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "A woman in a plaid shirt sits on a sandy beach at sunset, smiling as she gives a high-five to a yellow Labrador Retriever wearing a harness. The ocean waves roll in the background." + ```""" + outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, @@ -1480,12 +1599,24 @@ def forward( if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + aux_loss = None + if kwargs.get("output_router_logits", False): + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.config.text_config.num_experts, + self.config.text_config.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.config.text_config.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + return Qwen3VLMoeCausalLMOutputWithPast( loss=loss, + aux_loss=aux_loss, logits=logits, past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, rope_deltas=outputs.rope_deltas, ) diff --git a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py index ca59fbfd83da..48eba4633f93 100644 --- a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py @@ -14,21 +14,27 @@ # limitations under the License. """PyTorch Qwen3-VL-MOE model.""" +from typing import Optional, Union + import torch import torch.nn as nn from ...activations import ACT2FN +from ...cache_utils import Cache from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import PreTrainedModel -from ...utils import logging +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, logging from ..qwen3_moe.modeling_qwen3_moe import ( Qwen3MoeDecoderLayer, Qwen3MoePreTrainedModel, Qwen3MoeRMSNorm, + load_balancing_loss_func, ) from ..qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig from ..qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLCausalLMOutputWithPast, Qwen3VLForConditionalGeneration, Qwen3VLModel, Qwen3VLTextAttention, @@ -98,6 +104,8 @@ class Qwen3VLMoeTextConfig(PretrainedConfig): Number of routed experts. norm_topk_prob (`bool`, *optional*, defaults to `True`): Whether to normalize the topk probabilities. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. mlp_only_layers (`List[int]`, *optional*, defaults to `[]`): Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock The list contains layer index, from 0 to num_layers-1 if we have num_layers layers @@ -196,6 +204,7 @@ def __init__( num_experts_per_tok=4, num_experts=60, norm_topk_prob=True, + router_aux_loss_coef=0.001, mlp_only_layers=None, rope_scaling=None, head_dim=None, @@ -231,6 +240,7 @@ def __init__( self.num_experts_per_tok = num_experts_per_tok self.num_experts = num_experts self.norm_topk_prob = norm_topk_prob + self.router_aux_loss_coef = router_aux_loss_coef self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) @@ -415,12 +425,126 @@ class Qwen3VLMoeTextModel(Qwen3VLTextModel): pass +class Qwen3VLMoeCausalLMOutputWithPast(Qwen3VLCausalLMOutputWithPast): + aux_loss: Optional[torch.FloatTensor] = None + + class Qwen3VLMoeModel(Qwen3VLModel): pass class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): - pass + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + + Example: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration + + >>> model = Qwen3VLMoeForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image in short."}, + ], + } + ] + + >>> # Preparation for inference + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ) + >>> inputs = inputs.to(model.device) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=128) + >>> generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + >>> processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "A woman in a plaid shirt sits on a sandy beach at sunset, smiling as she gives a high-five to a yellow Labrador Retriever wearing a harness. The ocean waves roll in the background." + ```""" + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + aux_loss = None + if kwargs.get("output_router_logits", False): + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.config.text_config.num_experts, + self.config.text_config.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.config.text_config.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + return Qwen3VLMoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + rope_deltas=outputs.rope_deltas, + ) __all__ = [ From 2a11935ecd17569bb5385455cece1fb15a1010aa Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Thu, 2 Oct 2025 04:07:13 +0800 Subject: [PATCH 2/8] update qwen3vl processor test! --- .../qwen3_vl/test_processing_qwen3_vl.py | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/tests/models/qwen3_vl/test_processing_qwen3_vl.py b/tests/models/qwen3_vl/test_processing_qwen3_vl.py index d6d1938ccd57..eafda96ae0e5 100644 --- a/tests/models/qwen3_vl/test_processing_qwen3_vl.py +++ b/tests/models/qwen3_vl/test_processing_qwen3_vl.py @@ -37,7 +37,6 @@ @require_vision @require_torch @require_torchvision -@unittest.skip("The checkpoint is not yet released") class Qwen3VLProcessorTest(ProcessorTesterMixin, unittest.TestCase): processor_class = Qwen3VLProcessor @@ -45,7 +44,7 @@ class Qwen3VLProcessorTest(ProcessorTesterMixin, unittest.TestCase): def setUpClass(cls): cls.tmpdirname = tempfile.mkdtemp() processor = Qwen3VLProcessor.from_pretrained( - "Qwen/Qwen3-VL-4B-Instruct", patch_size=4, max_pixels=56 * 56, min_pixels=28 * 28 + "Qwen/Qwen3-VL-235B-A22B-Instruct", patch_size=4, max_pixels=56 * 56, min_pixels=28 * 28 ) processor.save_pretrained(cls.tmpdirname) cls.image_token = processor.image_token @@ -139,21 +138,15 @@ def test_processor(self): processor(images=image_input) def test_model_input_names(self): - image_processor = self.get_image_processor() - tokenizer = self.get_tokenizer() - video_processor = self.get_video_processor() - - processor = Qwen3VLProcessor( - tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor - ) + processor = self.get_processor() - input_str = "lower newer" + text = self.prepare_text_inputs(modalities=["image", "video"]) image_input = self.prepare_image_inputs() video_inputs = self.prepare_video_inputs() + inputs_dict = {"text": text, "images": image_input, "videos": video_inputs} + inputs = processor(**inputs_dict, return_tensors="pt", do_sample_frames=False) - inputs = processor(text=input_str, images=image_input, videos=video_inputs, do_sample_frames=False) - - self.assertListEqual(list(inputs.keys()), processor.model_input_names) + self.assertSetEqual(set(inputs.keys()), set(processor.model_input_names)) @require_torch @require_av @@ -299,6 +292,9 @@ def test_apply_chat_template_video_frame_sampling(self): out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True) self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"]) + # for fast test, set the longest edge to 4096 + processor.video_processor.size['longest_edge'] = 8192 + # Add video URL for return dict and load with `num_frames` arg messages[0][0]["content"][0] = { "type": "video", @@ -311,9 +307,10 @@ def test_apply_chat_template_video_frame_sampling(self): tokenize=True, return_dict=True, num_frames=num_frames, + fps=None, # if pass num_frames, fps should be None ) self.assertTrue(self.videos_input_name in out_dict_with_video) - self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 360) + self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 256) # Load with `fps` arg fps = 1 @@ -325,7 +322,7 @@ def test_apply_chat_template_video_frame_sampling(self): fps=fps, ) self.assertTrue(self.videos_input_name in out_dict_with_video) - self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 900) + self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 224) # Load with `fps` and `num_frames` args, should raise an error with self.assertRaises(ValueError): @@ -346,7 +343,7 @@ def test_apply_chat_template_video_frame_sampling(self): return_dict=True, ) self.assertTrue(self.videos_input_name in out_dict_with_video) - self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 27000) + self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 224) # Load video as a list of frames (i.e. images). NOTE: each frame should have same size # because we assume they come from one video @@ -365,7 +362,7 @@ def test_apply_chat_template_video_frame_sampling(self): do_sample_frames=False, ) self.assertTrue(self.videos_input_name in out_dict_with_video) - self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 160) + self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 216) def test_kwargs_overrides_custom_image_processor_kwargs(self): processor = self.get_processor() From ee58f2cc7bdc5012af33522dac6643f8609bf109 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Thu, 2 Oct 2025 15:10:18 +0800 Subject: [PATCH 3/8] add integration tests for qwen3vl-30a3 --- .../qwen3_vl/test_processing_qwen3_vl.py | 4 +- .../test_modeling_qwen3_vl_moe.py | 282 ++++++++++++++++++ 2 files changed, 284 insertions(+), 2 deletions(-) diff --git a/tests/models/qwen3_vl/test_processing_qwen3_vl.py b/tests/models/qwen3_vl/test_processing_qwen3_vl.py index eafda96ae0e5..14aa88ad4268 100644 --- a/tests/models/qwen3_vl/test_processing_qwen3_vl.py +++ b/tests/models/qwen3_vl/test_processing_qwen3_vl.py @@ -293,7 +293,7 @@ def test_apply_chat_template_video_frame_sampling(self): self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"]) # for fast test, set the longest edge to 4096 - processor.video_processor.size['longest_edge'] = 8192 + processor.video_processor.size["longest_edge"] = 8192 # Add video URL for return dict and load with `num_frames` arg messages[0][0]["content"][0] = { @@ -307,7 +307,7 @@ def test_apply_chat_template_video_frame_sampling(self): tokenize=True, return_dict=True, num_frames=num_frames, - fps=None, # if pass num_frames, fps should be None + fps=None, # if pass num_frames, fps should be None ) self.assertTrue(self.videos_input_name in out_dict_with_video) self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 256) diff --git a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py index e08e184e671a..cc82c48c1568 100644 --- a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py +++ b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py @@ -17,13 +17,18 @@ import unittest from transformers import ( + AutoProcessor, Qwen3VLMoeConfig, Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel, is_torch_available, ) from transformers.testing_utils import ( + cleanup, + require_flash_attn, require_torch, + require_torch_gpu, + slow, torch_device, ) @@ -296,3 +301,280 @@ def test_video_forward(self): video_grid_thw=video_grid_thw, ) self.assertIsNotNone(outputs) + + +@require_torch +@unittest.skip("The checkpoint is not yet released") +class Qwen3VLMoeIntegrationTest(unittest.TestCase): + def setUp(self): + cleanup(torch_device, gc_collect=True) + + self.processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct") + self.processor.tokenizer.padding_side = "left" + self.message = [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", + }, + {"type": "text", "text": "What kind of dog is this?"}, + ], + } + ] + self.message2 = [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png", + }, + {"type": "text", "text": "What kind of dog is this?"}, + ], + } + ] + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + def test_small_model_integration_test(self): + model = Qwen3VLMoeForConditionalGeneration.from_pretrained( + "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto" + ) + + inputs = self.processor.apply_chat_template( + self.message, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt" + ) + expected_input_ids = [151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655] # fmt: skip + assert expected_input_ids == inputs.input_ids[0].tolist()[:17] + + expected_pixel_slice = torch.tensor( + [ + [-0.0902, -0.0824, -0.0824], + [-0.2627, -0.2627, -0.2627], + [-0.0824, -0.0902, -0.0902], + [-0.0118, -0.0510, -0.1137], + [-0.5137, -0.5529, -0.6078], + [-0.6941, -0.6314, -0.5765], + ], + dtype=torch.float32, + device="cpu", + ) + assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=3e-3) + + # verify generation + inputs = inputs.to(torch_device) + + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + EXPECTED_DECODED_TEXT = "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a small wild cat native to the grasslands and steppes" + self.assertEqual( + self.processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + def test_small_model_integration_test_batch(self): + model = Qwen3VLMoeForConditionalGeneration.from_pretrained( + "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto" + ) + batch_messages = [self.message] * 2 + inputs = self.processor.apply_chat_template( + batch_messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt" + ).to(torch_device) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + + EXPECTED_DECODED_TEXT = [ + "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a small wild cat native to the grasslands and montane regions", + "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a small wild cat native to the grasslands and montane regions" + ] # fmt: skip + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + def test_small_model_integration_test_with_video(self): + processor = AutoProcessor.from_pretrained( + "Qwen/Qwen3-VL-30B-A3B-Instruct", max_image_size={"longest_edge": 50176} + ) + model = Qwen3VLMoeForConditionalGeneration.from_pretrained( + "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype=torch.float16, device_map="auto" + ) + questions = ["How long is the video? Describe the it in short."] + video_urls = ["https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4"] + messages = [ + [ + { + "role": "user", + "content": [ + { + "type": "video", + "video": video_url, + }, + {"type": "text", "text": question}, + ], + } + ] + for question, video_url in zip(questions, video_urls) + ] + inputs = processor.apply_chat_template( + messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", padding=True + ).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + EXPECTED_DECODED_TEXT = ["user\n<0.3 seconds><1.4 seconds><2.5 seconds><3.6 seconds><4.7 seconds><5.8 seconds>How long is the video? Describe the it in short.\nassistant\nThe video is 6 seconds long. It shows a man playing tennis on an indoor court. He is wearing a white shirt and black shorts. He"] # fmt: skip + + self.assertEqual( + processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + def test_small_model_integration_test_expand(self): + model = Qwen3VLMoeForConditionalGeneration.from_pretrained( + "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto" + ) + inputs = self.processor.apply_chat_template( + self.message, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt" + ).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=30, do_sample=False, num_beams=2, num_return_sequences=2) + + EXPECTED_DECODED_TEXT = [ + "user\nWhat kind of dog is this?\nassistant\nThe animal in the image is not a dog. It is a **Pallas's cat** (*Otocolobus manul*), also known", + "user\nWhat kind of dog is this?\nassistant\nThe animal in the image is not a dog. It is a **Pallas's cat** (also known as the manul), a wild f" + ] # fmt: skip + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + def test_small_model_integration_test_batch_wo_image(self): + model = Qwen3VLMoeForConditionalGeneration.from_pretrained( + "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto" + ) + message_wo_image = [ + {"role": "user", "content": [{"type": "text", "text": "Who are you?"}]}, + ] + batched_messages = [self.message, message_wo_image] + inputs = self.processor.apply_chat_template( + batched_messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + padding=True, + ).to(torch_device) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + + EXPECTED_DECODED_TEXT = [ + "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and steppes", + "user\nWho are you?\nassistant\nI am Qwen, a large-scale language model developed by Alibaba Cloud's Tongyi Lab. I can assist you with answering questions, creating text such" + ] # fmt: skip + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + def test_small_model_integration_test_batch_different_resolutions(self): + model = Qwen3VLMoeForConditionalGeneration.from_pretrained( + "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto" + ) + batched_messages = [self.message, self.message2] + inputs = self.processor.apply_chat_template( + batched_messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + padding=True, + ).to(torch_device) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + + EXPECTED_DECODED_TEXT = [ + "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and steppes", + "user\nWhat kind of dog is this?\nassistant\nBased on the image provided, the animals are not dogs. They are two cats.\n\nHere is a description of the animals in the image:\n\n- " + ] # fmt: skip + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_flash_attn + @require_torch_gpu + def test_small_model_integration_test_batch_flashatt2(self): + model = Qwen3VLMoeForConditionalGeneration.from_pretrained( + "Qwen/Qwen3-VL-30B-A3B-Instruct", + dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map="auto", + ) + batched_messages = [self.message, self.message2] + inputs = self.processor.apply_chat_template( + batched_messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + padding=True, + ).to(torch_device) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + + EXPECTED_DECODED_TEXT = [ + "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and montane regions", + "user\nWhat kind of dog is this?\nassistant\nBased on the image provided, there is no dog present. The animals in the picture are two cats.\n\nHere are some observations about the cats in the" + ] # fmt: skip + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_flash_attn + @require_torch_gpu + def test_small_model_integration_test_batch_wo_image_flashatt2(self): + model = Qwen3VLMoeForConditionalGeneration.from_pretrained( + "Qwen/Qwen3-VL-30B-A3B-Instruct", + dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map="auto", + ) + message_wo_image = [ + {"role": "user", "content": [{"type": "text", "text": "Who are you?"}]}, + ] + batched_messages = [self.message, message_wo_image] + inputs = self.processor.apply_chat_template( + batched_messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + padding=True, + ).to(torch_device) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + + EXPECTED_DECODED_TEXT = [ + "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and montane regions", + "user\nWho are you?\nassistant\nI am Qwen, a large-scale language model developed by Alibaba Cloud's Tongyi Lab. I can assist you with answering questions, creating text such" + ] # fmt: skip + + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) From 883ca16b9edadb6959da171086fef36a8c88863b Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Thu, 2 Oct 2025 19:06:37 +0800 Subject: [PATCH 4/8] remove duplicated decorator --- src/transformers/models/qwen3_vl/modeling_qwen3_vl.py | 3 +-- src/transformers/models/qwen3_vl/modular_qwen3_vl.py | 1 - src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 3 +-- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 2b217dd0db52..a6eec74f8009 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -37,7 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig @@ -1104,7 +1104,6 @@ def get_placeholder_mask( return special_image_mask, special_video_mask @auto_docstring - @can_return_tuple @check_model_inputs def forward( self, diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index c6834bd27c58..d55a470f0625 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -1006,7 +1006,6 @@ def get_video_features( return self.get_image_features(pixel_values_videos, video_grid_thw) @auto_docstring - @can_return_tuple @check_model_inputs def forward( self, diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 23d3e2026515..ccce10dde0ad 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -37,7 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig @@ -1246,7 +1246,6 @@ def get_placeholder_mask( return special_image_mask, special_video_mask @auto_docstring - @can_return_tuple @check_model_inputs def forward( self, From 433f56b73cbea91fe8d7150ef3060fe5132780dc Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Thu, 2 Oct 2025 19:18:45 +0800 Subject: [PATCH 5/8] code clean --- tests/models/qwen3_vl/test_processing_qwen3_vl.py | 2 +- tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/qwen3_vl/test_processing_qwen3_vl.py b/tests/models/qwen3_vl/test_processing_qwen3_vl.py index 14aa88ad4268..9ce056a207ac 100644 --- a/tests/models/qwen3_vl/test_processing_qwen3_vl.py +++ b/tests/models/qwen3_vl/test_processing_qwen3_vl.py @@ -292,7 +292,7 @@ def test_apply_chat_template_video_frame_sampling(self): out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True) self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"]) - # for fast test, set the longest edge to 4096 + # for fast test, set the longest edge to 8192 processor.video_processor.size["longest_edge"] = 8192 # Add video URL for return dict and load with `num_frames` arg diff --git a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py index cc82c48c1568..63ea91e9d027 100644 --- a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py +++ b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py @@ -349,7 +349,7 @@ def test_small_model_integration_test(self): self.message, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt" ) expected_input_ids = [151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655] # fmt: skip - assert expected_input_ids == inputs.input_ids[0].tolist()[:17] + self.assertListEqual(expected_input_ids, inputs.input_ids[0].tolist()[:17]) expected_pixel_slice = torch.tensor( [ @@ -363,7 +363,7 @@ def test_small_model_integration_test(self): dtype=torch.float32, device="cpu", ) - assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=3e-3) + self.assertTrue(torch.allclose(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=3e-3)) # verify generation inputs = inputs.to(torch_device) From 902af47d9a86cfe5e40654c9a819a923278f93a0 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Thu, 2 Oct 2025 19:22:03 +0800 Subject: [PATCH 6/8] fix consistency --- src/transformers/models/qwen3_vl/modular_qwen3_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index d55a470f0625..9c479c6bdd23 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -33,7 +33,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs from ...tokenization_utils_base import PreTokenizedInput, TextInput -from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils import auto_docstring, is_torchdynamo_compiling, logging from ...utils.generic import check_model_inputs from ...video_utils import VideoInput from ..qwen2_5_vl.modeling_qwen2_5_vl import ( From f5dea1c694af8c994c769170813a8702332119ee Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Thu, 2 Oct 2025 20:22:45 +0800 Subject: [PATCH 7/8] do not inherit from nn.Linear for better quantization --- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 32 +++++++------------ .../qwen3_vl_moe/modular_qwen3_vl_moe.py | 32 +++++++------------ 2 files changed, 22 insertions(+), 42 deletions(-) diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index ccce10dde0ad..96b516a2c7bc 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -64,25 +64,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Qwen3VLMoeTextRouter(nn.Linear): - def __init__(self, config): - super().__init__(config.hidden_size, config.num_experts, bias=False) - self.hidden_size = config.hidden_size - self.top_k = config.num_experts_per_tok - # since all the models use norm_topk_prob, we don't need to have a extra check for it - # self.norm_topk_prob = config.norm_topk_prob - - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_size) - router_logits = super().forward(hidden_states) - routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(hidden_states.dtype) - router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights) - return router_weights, router_logits, router_indices - - class Qwen3VLMoeTextExperts(nn.Module): def __init__(self, config): super().__init__() @@ -150,11 +131,20 @@ def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.num_experts = config.num_experts - self.gate = Qwen3VLMoeTextRouter(config) + self.top_k = config.num_experts_per_tok + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) self.experts = Qwen3VLMoeTextExperts(config) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - router_weights, router_logits, router_indices = self.gate(hidden_states) + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + router_logits = self.gate(hidden_states) + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights) + hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size) routed_out = self.experts(hidden_states, router_weights, router_indices) return routed_out diff --git a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py index 48eba4633f93..f428f9de5d00 100644 --- a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py @@ -298,25 +298,6 @@ class Qwen3VLMoeTextRMSNorm(Qwen3MoeRMSNorm): pass -class Qwen3VLMoeTextRouter(nn.Linear): - def __init__(self, config): - super().__init__(config.hidden_size, config.num_experts, bias=False) - self.hidden_size = config.hidden_size - self.top_k = config.num_experts_per_tok - # since all the models use norm_topk_prob, we don't need to have a extra check for it - # self.norm_topk_prob = config.norm_topk_prob - - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_size) - router_logits = super().forward(hidden_states) - routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(hidden_states.dtype) - router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights) - return router_weights, router_logits, router_indices - - class Qwen3VLMoeTextExperts(nn.Module): def __init__(self, config): super().__init__() @@ -384,11 +365,20 @@ def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.num_experts = config.num_experts - self.gate = Qwen3VLMoeTextRouter(config) + self.top_k = config.num_experts_per_tok + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) self.experts = Qwen3VLMoeTextExperts(config) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - router_weights, router_logits, router_indices = self.gate(hidden_states) + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + router_logits = self.gate(hidden_states) + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights) + hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size) routed_out = self.experts(hidden_states, router_weights, router_indices) return routed_out From ea6842384f14b2264adac029cdf3165cf2054278 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Thu, 2 Oct 2025 20:31:32 +0800 Subject: [PATCH 8/8] pass check --- src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 3 +++ src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 96b516a2c7bc..7a9952786bfb 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -135,6 +135,9 @@ def __init__(self, config): self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) self.experts = Qwen3VLMoeTextExperts(config) + # since all the models use norm_topk_prob, we don't need to have a extra check for it + # self.norm_topk_prob = config.norm_topk_prob + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) diff --git a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py index f428f9de5d00..30dda5f99497 100644 --- a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py @@ -369,6 +369,9 @@ def __init__(self, config): self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) self.experts = Qwen3VLMoeTextExperts(config) + # since all the models use norm_topk_prob, we don't need to have a extra check for it + # self.norm_topk_prob = config.norm_topk_prob + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size)