
Add Multi-Token Prediction (MTP) inference support#45617

Closed
ArthurZucker wants to merge 1 commit into huggingface:main from ArthurZucker:mtp-inference

Conversation

@ArthurZucker
Collaborator

Summary

  • Wires MTP speculative decoding into generate() for DeepSeek-V3 and GLM-4 MoE, loading the MTP modules that ship in the original checkpoints (previously ignored via _keys_to_ignore_on_load_unexpected).
  • New GenerationConfig.use_mtp=True opt-in (default False; no behavior change when off). Routes to a new GenerationMode.MTP_DECODING / _mtp_decoding that uses the model's own MTP heads as the draft, verifies all draft tokens in a single forward, and reuses the existing _speculative_sampling helper (usage sketch after this list).
  • Adds a tiny MTPLayer (RMSNorm + concat-and-project + decoder block + shared_head) mirroring vLLM's deepseek_mtp.py. Base forward keeps iterating only self.layers[:num_hidden_layers]; MTP is reached via a new model.forward_mtp(...) helper.
  • _assisted_decoding / candidate_generator.py are untouched.
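
A minimal usage sketch (the checkpoint name is illustrative; `use_mtp` and `num_nextn_predict_layers` are the knobs this PR adds):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint: any DeepSeek-V3 / GLM-4 MoE checkpoint that ships
# MTP modules qualifies once they are loaded instead of ignored.
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V3")
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3")

inputs = tokenizer("The capital of France is", return_tensors="pt")
# Explicit opt-in; draft depth comes from config.num_nextn_predict_layers.
out = model.generate(**inputs, use_mtp=True, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```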

Scope

  • generate(..., use_mtp=True) is supported for DeepSeek-V3 / GLM-4 MoE at batch_size=1 with dynamic cache.
  • generate_batch with use_mtp=True raises NotImplementedError for now (guard sketched after this list): continuous batching needs paged-cache slot reservation (K+1 slots per MTP request) plus per-request accept/reject in _forward_process_and_sample. That will come in a follow-up PR so it gets a clean review on its own.
  • Training-time MTP loss is explicitly out of scope.
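
The continuous-batching refusal is roughly this shape (a sketch, not the verbatim diff; the exact location and wording are illustrative):

```python
# Somewhere in ContinuousBatchingManager setup (illustrative placement):
if generation_config.use_mtp:
    raise NotImplementedError(
        "use_mtp=True is not supported with continuous batching yet: it needs "
        "paged-cache slot reservation (K+1 slots per request) and per-request "
        "accept/reject in _forward_process_and_sample."
    )
```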

Design decisions

| Decision | Choice |
| --- | --- |
| Draft | Model's own MTP heads (shared weights + shared KV cache). |
| Path | Dedicated `_mtp_decoding`, not routed through `_assisted_decoding`. |
| Activation | Explicit `use_mtp=True`; depth = `config.num_nextn_predict_layers`. |
| Layout | Tail of `self.layers` holds the MTP modules, matching the checkpoint key scheme `model.layers.{num_base + k}.*` (sketch below). |
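
A sketch of that layout, using the names from this description (an illustrative excerpt, not the verbatim implementation):

```python
# Inside DeepseekV3Model.__init__ (illustrative excerpt):
num_base = config.num_hidden_layers        # e.g. 61 for DeepSeek-V3
num_mtp = config.num_nextn_predict_layers  # 0 keeps the existing behavior

# Base stack: checkpoint keys model.layers.0.* .. model.layers.{num_base - 1}.*
self.layers = nn.ModuleList(
    [DeepseekV3DecoderLayer(config, i) for i in range(num_base)]
)
# MTP tail: checkpoint keys model.layers.{num_base + k}.*
self.layers.extend(
    DeepseekV3MTPLayer(config, num_base + k) for k in range(num_mtp)
)
# The base forward still iterates only self.layers[:num_base];
# forward_mtp walks the tail.
```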

Files

  • Generation: generation/configuration_utils.py (new use_mtp flag + mode), generation/utils.py (new _mtp_decoding, mode registration, validation), generation/continuous_batching/continuous_api.py (explicit refusal for now).
  • Models: deepseek_v3/ and glm4_moe/ modular/config/modeling. Downstream variants (glm4_moe_lite, glm4v_moe, glm_moe_dsa, longcat_flash, solar_open, youtu) were regenerated with the new field propagated through the modular chain.
  • Tests: new tests/generation/test_mtp.py.

Test plan

  • pytest tests/generation/test_mtp.py — 9/9 passing:
    • use_mtp=True routes to GenerationMode.MTP_DECODING.
    • Greedy output matches plain _sample token-for-token for K=1/2/3 on DeepSeek-V3 and GLM-4 MoE (random init, so drafts usually miss, which exercises the bonus-token fallback path); see the sketch after this list.
    • use_mtp=True with num_nextn_predict_layers == 0 raises a clear ValueError.
    • Base forward output is unchanged when MTP layers are added (shared-weight copy + compare).
    • model.forward_mtp produces (hidden, logits) of the expected shapes.
    • ContinuousBatchingManager(..., GenerationConfig(use_mtp=True)) raises NotImplementedError.
  • pytest tests/generation/test_candidate_generator.py tests/generation/test_configuration_utils.py — unchanged / green (regression check).
  • Import smoke-test on every modular-regenerated downstream model (glm4_moe_lite, glm4v_moe, glm_moe_dsa, longcat_flash, solar_open, youtu).
  • make style clean.
  • make fix-repo is clean apart from a pre-existing mlinter._using_rule_specs / check_modeling_structure.py --rules-toml toolchain mismatch that also fails on an unmodified checkout, i.e. unrelated to this PR.
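
The greedy-parity check is essentially the following (variable names illustrative):

```python
import torch

# With greedy decoding, MTP must be bit-identical to plain sampling: every
# draft token is verified against the base model before it is accepted.
baseline = model.generate(**inputs, do_sample=False, max_new_tokens=20)
with_mtp = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_mtp=True)
assert torch.equal(baseline, with_mtp)
```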

Follow-ups

  • generate_batch + MTP (paged-cache reservation, per-request accept/reject).
  • Prefill through the MTP layers during the initial forward to seed their KV caches for better draft quality (currently each step's MTP self-attention only sees the current query token, so acceptance is degraded vs vLLM, but outputs are correct).
  • Real-weight end-to-end test on deepseek-ai/DeepSeek-V3 once GPU capacity is available.

Wires MTP speculative decoding into `generate()` for DeepSeek-V3 and GLM-4 MoE
checkpoints that ship MTP modules (DeepSeek-V3 at `model.layers.61`,
GLM-4 MoE at `model.layers.46`/`.92` — previously hidden by
`_keys_to_ignore_on_load_unexpected`).

**Model side**
- New `num_nextn_predict_layers: int = 0` on `DeepseekV3Config` /
  `Glm4MoeConfig` (propagates to downstream variants). Default keeps the
  existing no-op behavior.
- `DeepseekV3MTPLayer` / `Glm4MoeMTPLayer` modules mirror the DeepSeek-V3
  spec as implemented in vLLM: `enorm` + `hnorm` RMSNorms → concat → linear
  `eh_proj(2H → H)` → a full decoder block → `shared_head (norm + lm_head)`.
  A sketch follows this list.
- `DeepseekV3Model` / `Glm4MoeModel` extend `self.layers` past
  `num_hidden_layers` with MTP modules; the base `forward` still iterates
  only `self.layers[: num_hidden_layers]`. MTP is reached exclusively via a
  new `model.forward_mtp(input_ids, previous_hidden_state, past_key_values,
  position_ids, mtp_depth)` helper (lazily extends the KV cache for MTP
  layer indices).
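
A hedged sketch of the layer, following the description above and vLLM's `deepseek_mtp.py` (attribute names and the decoder-block call are simplified stand-ins, not the merged code):

```python
import torch
from torch import nn
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
    DeepseekV3DecoderLayer,
    DeepseekV3RMSNorm,
)


class DeepseekV3MTPLayer(nn.Module):
    """RMSNorm both streams -> concat -> project 2H->H -> full decoder block."""

    def __init__(self, config, layer_idx):
        super().__init__()
        h = config.hidden_size
        self.enorm = DeepseekV3RMSNorm(h, eps=config.rms_norm_eps)  # token embeds
        self.hnorm = DeepseekV3RMSNorm(h, eps=config.rms_norm_eps)  # prev hidden
        self.eh_proj = nn.Linear(2 * h, h, bias=False)              # 2H -> H
        self.decoder_layer = DeepseekV3DecoderLayer(config, layer_idx)
        # shared_head = this norm plus the model's lm_head, applied by the caller.
        self.shared_head_norm = DeepseekV3RMSNorm(h, eps=config.rms_norm_eps)

    def forward(self, inputs_embeds, previous_hidden_state, **decoder_kwargs):
        merged = torch.cat(
            (self.enorm(inputs_embeds), self.hnorm(previous_hidden_state)), dim=-1
        )
        # Real decoder layers also take attention masks / position args; elided.
        hidden = self.decoder_layer(self.eh_proj(merged), **decoder_kwargs)
        return self.shared_head_norm(hidden)
```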

**Generation side**
- `GenerationConfig.use_mtp: bool = False` and a new
  `GenerationMode.MTP_DECODING` routed from `get_generation_mode` whenever
  the base mode is greedy or sample.
- `_mtp_decoding` in `generation/utils.py`: main forward → sample `x_{t+1}`
  → chain K MTP depths for draft tokens → single verify forward → reuses
  `_speculative_sampling` for accept/reject → `past_key_values.crop`.
  Batch size 1, dynamic cache; leaves `_assisted_decoding` untouched.
  A step sketch follows this list.
- `ContinuousBatchingManager` refuses `use_mtp=True` for now —
  paged-attention slot reservation + per-request accept/reject is tracked
  separately and will come as a follow-up.
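
One decoding step, in pseudocode (`forward_mtp` and `_speculative_sampling` are the pieces from this PR; `main_forward` and `greedy_or_sample` are hypothetical stand-ins):

```python
def _mtp_decoding_step(model, input_ids, position_ids, past_key_values, k):
    # 1) Main forward: sample x_{t+1} and keep the final hidden state.
    hidden, logits = main_forward(model, input_ids, past_key_values)
    next_token = greedy_or_sample(logits[:, -1])

    # 2) Chain K MTP depths; each depth drafts one more token from the
    #    previous token + hidden state (shared weights, shared KV cache).
    drafts, token = [], next_token
    for depth in range(k):
        hidden, mtp_logits = model.forward_mtp(
            token, hidden, past_key_values, position_ids, mtp_depth=depth
        )
        token = greedy_or_sample(mtp_logits[:, -1])
        drafts.append(token)

    # 3) Verify [next_token, *drafts] in one forward, accept/reject via
    #    _speculative_sampling, then past_key_values.crop(accepted_len).
    ...
```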

**Tests**
- `tests/generation/test_mtp.py` covers: mode dispatch, greedy
  token-for-token parity vs plain `_sample` for K=1/2/3 on both models,
  `num_nextn_predict_layers=0` rejection, layer extension, base-forward
  equivalence when MTP layers are added, `forward_mtp` shapes, and the
  `generate_batch` `NotImplementedError`.

All 9 MTP tests pass locally. `make style` clean. `make fix-repo` clean
apart from the pre-existing `mlinter._using_rule_specs` env mismatch in
`check_modeling_rules_doc.py` / `check_modeling_structure.py` that also
fails on an unmodified checkout.
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: deepseek_v3, glm4_moe, glm4_moe_lite, glm4v_moe, glm_moe_dsa, longcat_flash, solar_open, youtu

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker
Collaborator Author

Superseded by a cleaner refactor (MTP moved out of the model into MTPCandidateGenerator). Reopening from the origin branch.
