Add Multi-Token Prediction (MTP) inference support #45617
Closed
ArthurZucker wants to merge 1 commit into huggingface:main from
Conversation
Wires MTP speculative decoding into `generate()` for DeepSeek-V3 and GLM-4 MoE
checkpoints that ship MTP modules (DeepSeek-V3 at `model.layers.61`,
GLM-4 MoE at `model.layers.46`/`.92` — previously hidden by
`_keys_to_ignore_on_load_unexpected`).
**Model side**
- New `num_nextn_predict_layers: int = 0` on `DeepseekV3Config` /
`Glm4MoeConfig` (propagates to downstream variants). Default keeps the
existing no-op behavior.
- `DeepseekV3MTPLayer` / `Glm4MoeMTPLayer` modules mirror the DeepSeek-V3
spec as implemented in vLLM: `enorm` + `hnorm` RMSNorms → concat → linear
  `eh_proj(2H → H)` → a full decoder block → `shared_head (norm + lm_head)` (sketched after this list).
- `DeepseekV3Model` / `Glm4MoeModel` extend `self.layers` past
`num_hidden_layers` with MTP modules; the base `forward` still iterates
only `self.layers[: num_hidden_layers]`. MTP is reached exclusively via a
new `model.forward_mtp(input_ids, previous_hidden_state, past_key_values,
position_ids, mtp_depth)` helper (lazily extends the KV cache for MTP
layer indices).
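
A minimal sketch of that MTP block, assuming the wiring described above (illustrative only — the RMSNorm/decoder internals and the exact modeling-code signatures are simplified):

```python
import torch
import torch.nn as nn

class MTPLayerSketch(nn.Module):
    """Illustrative MTP module: enorm/hnorm -> concat -> eh_proj(2H -> H) -> decoder block."""

    def __init__(self, config, decoder_layer: nn.Module):
        super().__init__()
        hidden = config.hidden_size
        self.enorm = nn.RMSNorm(hidden, eps=config.rms_norm_eps)  # normalizes the token embedding
        self.hnorm = nn.RMSNorm(hidden, eps=config.rms_norm_eps)  # normalizes the incoming hidden state
        self.eh_proj = nn.Linear(2 * hidden, hidden, bias=False)  # fuses the two streams: 2H -> H
        self.decoder = decoder_layer                              # one full decoder block

    def forward(self, input_embeds, previous_hidden_state, **decoder_kwargs):
        # Concatenate the normalized embedding of the shifted next token with the
        # normalized hidden state from the previous depth, project back to H, and
        # run a regular decoder block over the result.
        fused = self.eh_proj(
            torch.cat([self.enorm(input_embeds), self.hnorm(previous_hidden_state)], dim=-1)
        )
        # shared_head (final norm + lm_head) then maps this output to draft logits.
        return self.decoder(fused, **decoder_kwargs)
```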
**Generation side**
- `GenerationConfig.use_mtp: bool = False` and a new
`GenerationMode.MTP_DECODING` routed from `get_generation_mode` whenever
the base mode is greedy or sample.
- `_mtp_decoding` in `generation/utils.py`: main forward → sample `x_{t+1}`
→ chain K MTP depths for draft tokens → single verify forward → reuses
`_speculative_sampling` for accept/reject → `past_key_values.crop`.
  Batch size 1, dynamic cache; leaves `_assisted_decoding` untouched (loop sketched below).
- `ContinuousBatchingManager` refuses `use_mtp=True` for now —
paged-attention slot reservation + per-request accept/reject is tracked
separately and will come as a follow-up.
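
In outline, one iteration of that loop looks like the following sketch (pseudocode-level: `sample_fn` is a stand-in for greedy/sampled token selection, and `accept_or_reject` is a hypothetical greedy stand-in for `_speculative_sampling`, not the real helper):

```python
import torch

def accept_or_reject(candidates, verify_logits):
    # Hypothetical greedy stand-in for _speculative_sampling: keep drafted tokens
    # while they match the verifier's argmax, then append the verifier's own
    # prediction as the "bonus" token.
    verified = verify_logits.argmax(dim=-1)                       # (1, K + 1)
    match = (candidates[:, 1:] == verified[:, :-1]).long().cumprod(dim=-1)
    n_accept = int(match.sum())                                   # accepted draft count
    new_tokens = torch.cat(
        [candidates[:, : n_accept + 1], verified[:, n_accept : n_accept + 1]], dim=-1
    )
    return new_tokens, n_accept

def mtp_decoding_step(model, input_ids, past_key_values, num_mtp_depths, sample_fn):
    # 1) Main forward: logits for x_{t+1} plus the last hidden state seeding the MTP chain.
    out = model(input_ids=input_ids, past_key_values=past_key_values,
                output_hidden_states=True, use_cache=True)
    drafts = [sample_fn(out.logits[:, -1:])]                      # x_{t+1}, shape (1, 1)
    hidden = out.hidden_states[-1][:, -1:, :]                     # (1, 1, H)

    # 2) Chain the K MTP depths to draft K further tokens.
    for depth in range(num_mtp_depths):
        hidden, logits = model.model.forward_mtp(
            drafts[-1], hidden, past_key_values, position_ids=None, mtp_depth=depth
        )
        drafts.append(sample_fn(logits[:, -1:]))
    candidates = torch.cat(drafts, dim=-1)                        # (1, K + 1)

    # 3) Single verify forward over all drafted tokens, accept/reject, then roll
    #    the KV cache back so only the accepted positions persist.
    verify_logits = model(input_ids=candidates, past_key_values=past_key_values).logits
    new_tokens, n_accept = accept_or_reject(candidates, verify_logits)
    past_key_values.crop(past_key_values.get_seq_length() - (candidates.shape[1] - n_accept - 1))
    return new_tokens
```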
**Tests**
- `tests/generation/test_mtp.py` covers: mode dispatch, greedy
token-for-token parity vs plain `_sample` for K=1/2/3 on both models,
`num_nextn_predict_layers=0` rejection, layer extension, base-forward
equivalence when MTP layers are added, `forward_mtp` shapes, and the
`generate_batch` `NotImplementedError`.
All 9 MTP tests pass locally. `make style` clean. `make fix-repo` clean
apart from the pre-existing `mlinter._using_rule_specs` env mismatch in
`check_modeling_rules_doc.py` / `check_modeling_structure.py` that also
fails on an unmodified checkout.
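
End to end, the opt-in would look roughly like this (a usage sketch; the real checkpoint is far too large for casual local runs, and `num_nextn_predict_layers=1` is passed explicitly here only to make the opt-in visible):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3")
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-V3",
    num_nextn_predict_layers=1,  # load model.layers.61 instead of ignoring it
)

inputs = tok("The capital of France is", return_tensors="pt")
# use_mtp=True routes generate() to GenerationMode.MTP_DECODING / _mtp_decoding.
out = model.generate(**inputs, max_new_tokens=32, use_mtp=True)
print(tok.decode(out[0], skip_special_tokens=True))
```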
[For maintainers] Suggested jobs to run (before merge): run-slow: deepseek_v3, glm4_moe, glm4_moe_lite, glm4v_moe, glm_moe_dsa, longcat_flash, solar_open, youtu
Author: Superseded by a cleaner refactor (MTP moved out of the model into `MTPCandidateGenerator`). Reopening from the origin branch.
**Summary**
- MTP speculative decoding wired into `generate()` for DeepSeek-V3 and GLM-4 MoE, loading the MTP modules that ship in the original checkpoints (previously ignored via `_keys_to_ignore_on_load_unexpected`).
- `GenerationConfig.use_mtp=True` opt-in (default `False`; no behavior change when off). Routes to a new `GenerationMode.MTP_DECODING` / `_mtp_decoding` that uses the model's own MTP heads as the draft, verifies in a single forward, and reuses the existing `_speculative_sampling` helper.
- New `MTPLayer` (RMSNorm + concat-and-project + decoder block + `shared_head`) mirroring vLLM's `deepseek_mtp.py`. The base `forward` keeps iterating only `self.layers[:num_hidden_layers]`; MTP is reached via a new `model.forward_mtp(...)` helper.
- `_assisted_decoding` / `candidate_generator.py` are untouched.
**Scope**
- `generate(..., use_mtp=True)` is supported for DeepSeek-V3 / GLM-4 MoE at `batch_size=1` with dynamic cache.
- `generate_batch` with `use_mtp=True` raises `NotImplementedError` for now — continuous batching needs paged-cache slot reservation (K+1 per MTP request) plus per-request accept/reject in `_forward_process_and_sample`. That will come as a follow-up PR so it gets a clean review on its own (guard sketched below).
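Conceptually, the refusal is just an upfront guard, something like the following (a sketch — the actual check and wording live in `generation/continuous_batching/continuous_api.py`):

```python
# Illustrative guard in ContinuousBatchingManager (not the literal source).
if getattr(generation_config, "use_mtp", False):
    raise NotImplementedError(
        "use_mtp=True is not supported with continuous batching yet: it needs "
        "paged-cache slot reservation (K+1 per request) and per-request "
        "accept/reject. Tracked as a follow-up PR."
    )
```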
**Design decisions**
- A dedicated `_mtp_decoding` loop, not routed through `_assisted_decoding`.
- Opt-in via `use_mtp=True`; depth = `config.num_nextn_predict_layers`.
- `self.layers` holds the MTP modules (matches the checkpoint key scheme `model.layers.{num_base + k}.*`, illustrated below).
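The indexing convention is easiest to see with DeepSeek-V3's numbers (61 base layers, one MTP depth):

```python
# model.layers.0 .. model.layers.60 -> regular decoder blocks (base forward)
# model.layers.61                   -> MTP module, reached only via forward_mtp
num_base, num_mtp = 61, 1  # num_hidden_layers, num_nextn_predict_layers
mtp_indices = [num_base + k for k in range(num_mtp)]
assert mtp_indices == [61]  # matches the checkpoint keys model.layers.61.*
```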
**Files**
- `generation/configuration_utils.py` (new `use_mtp` flag + mode), `generation/utils.py` (new `_mtp_decoding`, mode registration, validation), `generation/continuous_batching/continuous_api.py` (explicit refusal for now).
- `deepseek_v3/` and `glm4_moe/` modular/config/modeling. Downstream variants (`glm4_moe_lite`, `glm4v_moe`, `glm_moe_dsa`, `longcat_flash`, `solar_open`, `youtu`) got regenerated with the new field propagated through the modular chain.
- `tests/generation/test_mtp.py`.
**Test plan**
- `pytest tests/generation/test_mtp.py` — 9/9 passing:
  - `use_mtp=True` routes to `GenerationMode.MTP_DECODING`.
  - Greedy parity vs plain `_sample`, token for token, for K=1/2/3 on DeepSeek-V3 and GLM-4 MoE (random-init, so drafts usually miss → exercises the verified bonus-token fallback path; sketched after this list).
  - `use_mtp=True` with `num_nextn_predict_layers == 0` raises a clear `ValueError`.
  - `model.forward_mtp` produces `(hidden, logits)` of the expected shapes.
  - `ContinuousBatchingManager(..., GenerationConfig(use_mtp=True))` raises `NotImplementedError`.
- `pytest tests/generation/test_candidate_generator.py tests/generation/test_configuration_utils.py` — unchanged / green (regression check).
- Downstream modular variants regenerated cleanly (`glm4_moe_lite`, `glm4v_moe`, `glm_moe_dsa`, `longcat_flash`, `solar_open`, `youtu`).
- `make style` clean.
- `make fix-repo` clean apart from a pre-existing `mlinter._using_rule_specs` / `check_modeling_structure.py --rules-toml` toolchain mismatch that also fails on an unmodified checkout — unrelated to this PR.
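The parity check is the load-bearing test; in spirit it amounts to the following (a condensed sketch, not the literal code from `tests/generation/test_mtp.py`):

```python
import torch

def assert_greedy_parity(model, input_ids, max_new_tokens: int = 16):
    """MTP decoding must emit exactly the same tokens as plain greedy _sample."""
    baseline = model.generate(input_ids, do_sample=False, max_new_tokens=max_new_tokens)
    with_mtp = model.generate(input_ids, do_sample=False, max_new_tokens=max_new_tokens, use_mtp=True)
    assert torch.equal(baseline, with_mtp)
```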
**Follow-ups**
- `generate_batch` + MTP (paged-cache reservation, per-request accept/reject).
- Validation on real `deepseek-ai/DeepSeek-V3` weights once GPU capacity is available.