Add MTP speculative decoding via MTPCandidateGenerator #45618
Open
ArthurZucker wants to merge 2 commits into main from
Wires MTP speculative decoding into `generate()` for DeepSeek-V3 and GLM-4 MoE
checkpoints that ship MTP modules (DeepSeek-V3 at `model.layers.61`,
GLM-4 MoE at `model.layers.46`/`.92` — previously hidden by
`_keys_to_ignore_on_load_unexpected`).
**Model side**
- New `num_nextn_predict_layers: int = 0` on `DeepseekV3Config` /
`Glm4MoeConfig` (propagates to downstream variants). Default keeps the
existing no-op behavior.
- `DeepseekV3MTPLayer` / `Glm4MoeMTPLayer` modules mirror the DeepSeek-V3
spec as implemented in vLLM: `enorm` + `hnorm` RMSNorms → concat → linear
`eh_proj(2H → H)` → a full decoder block → `shared_head (norm + lm_head)`.
- `DeepseekV3Model` / `Glm4MoeModel` extend `self.layers` past
`num_hidden_layers` with MTP modules; the base `forward` still iterates
only `self.layers[: num_hidden_layers]`. MTP is reached exclusively via a
new `model.forward_mtp(input_ids, previous_hidden_state, past_key_values,
position_ids, mtp_depth)` helper (lazily extends the KV cache for MTP
layer indices).
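The MTP block structure described above (`enorm` + `hnorm` RMSNorms → concat → `eh_proj(2H → H)` → decoder block → shared head) can be sketched in plain PyTorch. This is a toy stand-in, not the PR's code: `ToyMTPLayer`, the `nn.TransformerEncoderLayer` used in place of a full DeepSeek-V3 decoder block, and all sizes are illustrative.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        var = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(var + self.eps)

class ToyMTPLayer(nn.Module):
    """Illustrative stand-in for DeepseekV3MTPLayer / Glm4MoeMTPLayer."""
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.enorm = RMSNorm(hidden_size)   # normalizes the token embedding
        self.hnorm = RMSNorm(hidden_size)   # normalizes the previous hidden state
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        # Stand-in for the full decoder block the real layer contains.
        self.decoder = nn.TransformerEncoderLayer(
            d_model=hidden_size, nhead=4, batch_first=True)
        # shared_head = norm + lm_head in the real spec.
        self.shared_head_norm = RMSNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, token_embeds, prev_hidden):
        # enorm + hnorm -> concat -> eh_proj(2H -> H)
        fused = torch.cat([self.enorm(token_embeds), self.hnorm(prev_hidden)], dim=-1)
        hidden = self.decoder(self.eh_proj(fused))
        logits = self.lm_head(self.shared_head_norm(hidden))
        return hidden, logits

layer = ToyMTPLayer(hidden_size=32, vocab_size=100)
h, logits = layer(torch.randn(1, 5, 32), torch.randn(1, 5, 32))
print(tuple(h.shape), tuple(logits.shape))  # (1, 5, 32) (1, 5, 100)
```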
**Generation side**
- `GenerationConfig.use_mtp: bool = False` and a new
`GenerationMode.MTP_DECODING` routed from `get_generation_mode` whenever
the base mode is greedy or sample.
- `_mtp_decoding` in `generation/utils.py`: main forward → sample `x_{t+1}`
→ chain K MTP depths for draft tokens → single verify forward → reuses
`_speculative_sampling` for accept/reject → `past_key_values.crop`.
Batch size 1, dynamic cache; leaves `_assisted_decoding` untouched.
- `ContinuousBatchingManager` refuses `use_mtp=True` for now —
paged-attention slot reservation + per-request accept/reject is tracked
separately and will come as a follow-up.
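The `_mtp_decoding` control flow above can be sketched with toy callables (names like `mtp_decode_step`, `main_forward`, and `draft_next` are illustrative, not the actual `transformers` API). For clarity the sketch re-scores one draft at a time; the real path does a single batched verify forward and reuses `_speculative_sampling` for accept/reject.

```python
def mtp_decode_step(main_forward, draft_next, tokens, k):
    """One iteration: sample x_{t+1}, chain K MTP drafts, verify, accept a prefix."""
    next_tok = main_forward(tokens)          # base model predicts x_{t+1}
    drafts, ctx = [], tokens + [next_tok]
    for _ in range(k):                       # chain K MTP depths for draft tokens
        d = draft_next(ctx)
        drafts.append(d)
        ctx = ctx + [d]
    accepted = [next_tok]
    verify_ctx = tokens + [next_tok]
    for d in drafts:
        target = main_forward(verify_ctx)    # what the base model would emit here
        if target != d:                      # greedy accept/reject
            accepted.append(target)          # bonus token from the verify pass
            break
        accepted.append(d)
        verify_ctx = verify_ctx + [d]
    return accepted

# Toy "models": base model emits last_token + 1; the drafter is only right
# for small tokens, so two drafts are accepted and the third is rejected.
base = lambda seq: seq[-1] + 1
draft = lambda seq: seq[-1] + (1 if seq[-1] < 3 else 2)
out = mtp_decode_step(base, draft, [0], k=3)
print(out)  # [1, 2, 3, 4] — two accepted drafts plus the bonus token
```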
**Tests**
- `tests/generation/test_mtp.py` covers: mode dispatch, greedy
token-for-token parity vs plain `_sample` for K=1/2/3 on both models,
`num_nextn_predict_layers=0` rejection, layer extension, base-forward
equivalence when MTP layers are added, `forward_mtp` shapes, and the
`generate_batch` `NotImplementedError`.
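The parity tests above rest on one invariant: greedy speculative decoding must match plain greedy decoding token-for-token, whatever the draft quality, because each rejected draft is replaced by the base model's own argmax. A self-contained toy check of that invariant (not the real test code; `plain_greedy` and `speculative_greedy` are illustrative):

```python
def plain_greedy(base, prompt, n):
    seq = list(prompt)
    for _ in range(n):
        seq.append(base(seq))
    return seq

def speculative_greedy(base, draft, prompt, n, k):
    seq = list(prompt)
    while len(seq) < len(prompt) + n:
        ctx = list(seq)
        for _ in range(k):                  # propose k draft tokens
            ctx.append(draft(ctx))
        for cand in ctx[len(seq):]:         # verify each against the base model
            target = base(seq)
            seq.append(target)              # accept the draft or substitute the argmax
            if target != cand:
                break                       # first mismatch ends this round
    return seq[:len(prompt) + n]

base = lambda s: (s[-1] * 3 + 1) % 17               # arbitrary deterministic "model"
draft = lambda s: base(s) if s[-1] % 2 else (s[-1] + 7) % 17  # right about half the time
for k in (1, 2, 3):
    assert speculative_greedy(base, draft, [5], 8, k) == plain_greedy(base, [5], 8)
print("parity holds for K=1/2/3")
```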
All 9 MTP tests pass locally. `make style` clean. `make fix-repo` clean
apart from the pre-existing `mlinter._using_rule_specs` env mismatch in
`check_modeling_rules_doc.py` / `check_modeling_structure.py` that also
fails on an unmodified checkout.
**Update**
- MTP modules no longer live on `DeepseekV3Model` / `Glm4MoeModel` (no more layer-list extension or `forward_mtp` method); configs still expose `num_nextn_predict_layers` as metadata.
- New `transformers.generation.candidate_generators.MTPCandidateGenerator` (an `nn.Module` implementing `CandidateGenerator`) owns the MTP layers and introspects the base model's decoder and RMSNorm classes to build them. `from_pretrained` pulls the MTP-specific keys out of the main checkpoint.
- `_mtp_decoding`: uses `self.model(...)` → `last_hidden_state` + `self.lm_head` instead of forcing `output_hidden_states=True`.
- Tests updated for the new architecture; all 9 MTP tests plus 55 tests in the generation suite pass.
[For maintainers] Suggested jobs to run (before merge): `run-slow: deepseek_v3, glm4_moe, glm4v_moe, solar_open, youtu`
Adds `use_mtp=True` to `generate` for DeepSeek-V3 / GLM-4 MoE. MTP modules live in a new `generation.candidate_generators.MTPCandidateGenerator` (loaded via its `from_pretrained`); the base model stays clean.