Skip to content

Add Multi-Token Prediction (MTP) support for Qwen3.5#45637

Closed
curnane-lab wants to merge 1 commit intohuggingface:mainfrom
curnane-lab:feature/qwen35-mtp-clean
Closed

Add Multi-Token Prediction (MTP) support for Qwen3.5#45637
curnane-lab wants to merge 1 commit intohuggingface:mainfrom
curnane-lab:feature/qwen35-mtp-clean

Conversation

@curnane-lab
Copy link
Copy Markdown

Add Multi-Token Prediction (MTP) support for Qwen3.5

This PR adds Multi-Token Prediction (MTP) architecture and loss computation for Qwen3.5 models, enabling multi-token prediction during training for improved efficiency.

Changes

New classes:

  • Qwen3_5MTPLayer: Single MTP transformer layer with attention and MLP
  • Qwen3_5MTP: Top-level MTP module with FC fusion, layers, and norm

New shared helper:

  • _compute_qwen35_mtp_loss(): Shared MTP loss computation function used by both CausalLM and VL models, eliminating code duplication

Modified models:

  • Qwen3_5ForCausalLM: Added MTP initialization and loss computation in forward pass
  • Qwen3_5ForConditionalGeneration: Added MTP initialization and loss computation in forward pass

Configuration:

  • Added mtp_num_hidden_layers (default: 0) and mtp_loss_weight (default: 0.0) to both Qwen3_5TextConfig and Qwen3_5Config
  • Removed mtp from _keys_to_ignore_on_load_unexpected in Qwen3_5ForCausalLM so MTP weights are properly loaded from checkpoints

Design decisions

  1. Shared loss function: The _compute_qwen35_mtp_loss() helper eliminates code duplication between the text-only and VL models. Both models delegate to this shared function with their respective embed_tokens and rotary_emb references.

  2. MTP loss stays in model files: Following the pattern of other auxiliary losses in transformers (e.g., MoE router losses), MTP loss is computed within the model's forward pass rather than in a separate trainer class.

  3. Backward compatible: With mtp_num_hidden_layers=0 (default), MTP is disabled and the models behave identically to before.

  4. Checkpoint alignment: The MTP module structure aligns with the Qwen3.5 checkpoint format:

    • mtp.pre_fc_norm_hidden.*
    • mtp.pre_fc_norm_embedding.*
    • mtp.fc.*
    • mtp.layers.N.*
    • mtp.norm.*

Testing

Tested with Qwen3.5-MTP model checkpoints to verify weight loading and loss computation.

Add MTP architecture and loss computation for Qwen3.5 models, enabling
multi-token prediction during training for improved efficiency.

Changes:
- Add Qwen3_5MTPLayer and Qwen3_5MTP module classes
- Add shared _compute_qwen35_mtp_loss() helper function
- Add MTP support to Qwen3_5ForCausalLM (text-only model)
- Add MTP support to Qwen3_5ForConditionalGeneration (VL model)
- Add mtp_num_hidden_layers and mtp_loss_weight config fields
- Remove mtp from _keys_to_ignore_on_load_unexpected in CausalLM
- Regenerate modeling_qwen3_5.py and configuration_qwen3_5.py
@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen3_5

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants