
[Kimi] support Eagle3 speculative decoding for Kimi K2.5 #631

Open
yhl-amd wants to merge 1 commit into ROCm:main from yhl-amd:support_kimi_draft

Conversation


@yhl-amd commented Apr 22, 2026

Eagle3 Speculative Decoding for Kimi K2.5

Closes ROCm/ATOM#553

1. Goal

Add Eagle3 speculative decoding support for Kimi K2.5 (MLA target) in ATOM:

  • Functional: Eagle3 draft (standard MHA, independent KV cache, aux hidden state feedback) interoperates correctly with the MLA target model
  • Performance: the target model can use CUDAGraph in Eagle3 mode, on the same fast path as MTP

The draft checkpoint used in this PR is lightseekorg/kimi-k2.5-eagle3 — a 1-layer standard Llama decoder (MHA). The MLA-draft variant is not yet released upstream; the abstraction introduced here is format-agnostic and can be extended to MLA draft when available.

End result: GSM8K 5-shot acceptance rate 65.9%, accuracy 93.78% on Kimi-K2.5-MXFP4.

2. Design

Eagle3 differs from the existing MTP path in ATOM along four key dimensions:

| Aspect | MTP | Eagle3 |
| --- | --- | --- |
| Draft attention | Same as target (MLA) | Standard MHA |
| Draft input | Target's last hidden state | Concatenated aux hidden states from multiple target layers, projected through fc |
| KV cache | Shared with target | Allocated separately |
| Forward return | Single tensor | (hidden, aux_hidden_states) tuple |

Three pieces of design address these differences:

  • DraftKVCache abstraction: encapsulates the draft model's independent cache / scale / block_size / num_blocks. Draft and target share the same memory budget pool, but block granularity is computed independently (draft_num_blocks = num_kvcache_blocks × target_block_size / draft_block_size, see the sketch after this list) so total token capacity stays consistent. The abstraction is not bound to Eagle3; any future draft whose attention format diverges from the target can reuse it.
  • Standard paged attention path for draft: the draft uses block_size=16/1024 aligned with the MHA backend, so it goes through paged_attention_triton / pa_decode_gluon directly. Attention metadata (block_tables / context_lens / slot_mapping) is reused from the scheduler-populated forward_vars once before the propose loop, then incrementally updated with context_lens += 1 per draft step.
  • Aux hidden state through CUDAGraph: a new graph_aux_hidden: dict[(bs, max_q_len), list[Tensor]] stores references to the aux tensors materialized inside the captured graph. After replay, slices are taken to the actual num_tokens for the drafter to consume. Type dispatch uses a use_aux_hidden_state_outputs flag set at init time, avoiding runtime isinstance checks that would pollute graph capture.
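
A minimal sketch of the block-granularity math from the first bullet (the function name is illustrative, not the actual ATOM API):

```python
def draft_num_blocks(num_kvcache_blocks: int,
                     target_block_size: int,
                     draft_block_size: int) -> int:
    # Draft and target draw from the same memory budget, so the draft's
    # block count is chosen to keep total token capacity identical:
    #   draft_num_blocks * draft_block_size
    #     == num_kvcache_blocks * target_block_size
    return num_kvcache_blocks * target_block_size // draft_block_size
```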

3. Comparison with the existing MTP workflow

| Aspect | MTP (existing) | Eagle3 (this PR) |
| --- | --- | --- |
| Target forward return | Single tensor | (hidden, aux_hidden_states) tuple |
| Target CUDAGraph | Captured directly | Same capture path; graph_aux_hidden saves aux references, sliced to num_tokens after replay |
| Draft KV cache | Shares physical memory with target | DraftKVCache allocated independently with its own block granularity |
| Draft attention metadata | Shared MLA metadata | Reused from forward_vars (block_tables / context_lens / slot_mapping) |
| Draft cross-step hidden | Last hidden (same tensor) | Pre-norm hidden at midlayer exit (avoids a second RMSNorm at decoder entry) |
| Draft model forward return | Single tensor | (post_norm, pre_norm); logits use post_norm, next step uses pre_norm |
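
The graph_aux_hidden mechanism from the tables above, as an illustrative sketch (hypothetical helper names; the real logic lives in atom/model_engine/model_runner.py):

```python
import torch

# (batch_size, max_q_len) -> references to aux tensors inside the captured graph
graph_aux_hidden: dict[tuple[int, int], list[torch.Tensor]] = {}

def capture(graph: torch.cuda.CUDAGraph, target_model, static_inputs, bs, max_q_len):
    with torch.cuda.graph(graph):
        hidden, aux_hidden_states = target_model(*static_inputs)  # tuple return
    # Keep references to the graph-internal aux buffers; replay rewrites
    # them in place, so the references stay valid across iterations.
    graph_aux_hidden[(bs, max_q_len)] = aux_hidden_states
    return hidden

def replay(graph: torch.cuda.CUDAGraph, bs, max_q_len, num_tokens):
    graph.replay()
    # Slice the padded graph buffers down to the actual token count.
    return [t[:num_tokens] for t in graph_aux_hidden[(bs, max_q_len)]]
```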

Flow diagram

```
                  ┌──────────────────────────────────────┐
 Target forward   │  CUDAGraph capture / replay          │  ──► hidden
                  │  + graph_aux_hidden[(bs, q)] saves   │  ──► aux_hidden
                  └──────────────────────────────────────┘      (sliced to num_tokens after replay)
                                                            │
        ┌───────────────────────────────────────────────────┘
        ▼
  ┌─────────────────────────────────────────────────────────────────┐
  │ propose loop (i = 0 .. mtp_k-1)                                 │
  │                                                                 │
  │  ┌──────────────────────────────────────────────────┐           │
  │  │ Once: block_tables / context_lens / slot_mapping │  before   │
  │  │       ← forward_vars (scheduler-populated)       │   loop    │
  │  └──────────────────────────────────────────────────┘           │
  │                  ▼                                              │
  │  Draft forward  ─►  paged_attention_triton (block_size=16)      │
  │                  ─►  (post_norm, pre_norm) two tensors          │
  │                  ▼                                              │
  │  logits     ◄── post_norm                                       │
  │  next hidden ◄── pre_norm  (midlayer exit, before final norm)   │
  │  context_lens += 1  (incremental, no rebuild)                   │
  └─────────────────────────────────────────────────────────────────┘
```
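
The propose loop from the diagram, condensed into hypothetical Python (the actual implementation in atom/spec_decode/eagle.py differs in detail):

```python
import torch

def propose(draft_model, lm_head, sample, fc, forward_vars,
            input_ids, aux_hidden_states, mtp_k):
    # Attention metadata comes from the scheduler-populated forward_vars,
    # read once before the loop; only context_lens changes per step.
    block_tables = forward_vars.block_tables
    slot_mapping = forward_vars.slot_mapping
    context_lens = forward_vars.context_lens.clone()

    draft_tokens = []
    hidden = fc(torch.cat(aux_hidden_states, dim=-1))  # Eagle3 draft input
    for _ in range(mtp_k):
        post_norm, pre_norm = draft_model(
            input_ids, hidden, block_tables, context_lens, slot_mapping)
        input_ids = sample(lm_head(post_norm))  # logits from post-norm hidden
        draft_tokens.append(input_ids)
        hidden = pre_norm   # pre-norm (midlayer exit) feeds the next step
        context_lens += 1   # incremental update, no metadata rebuild
    return draft_tokens
```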

4. Changes

| File | Lines | Description |
| --- | --- | --- |
| atom/model_engine/model_runner.py | +247 | DraftKVCache + _slice_kv_cache_block(); Eagle3 KV cache allocation/binding; CUDAGraph capture handles tuple return + graph_aux_hidden; aux layer registration |
| atom/models/eagle3_llama.py | +303 (new) | Eagle3LlamaModel, Eagle3LlamaDecoderLayer; forward returns (post_norm, pre_norm) |
| atom/spec_decode/eagle.py | +152 | Pull attention metadata from forward_vars before the propose loop; unpack (post_norm, pre_norm); feed pre_norm to the next draft step |
| atom/model_engine/arg_utils.py | +40 | CLI args --method eagle3, --draft-model, --eagle3-aux-layer-ids |
| atom/config.py | +35 | SpeculativeConfig parses Eagle3 fields; argument validation; default use_aux_hidden_state=True when eagle_config is absent |
| atom/models/deepseek_v2.py | +24 | Aux hidden state collection hooks; default layer IDs (2, N//2, N-3), aligned with vLLM |
| atom/model_ops/linear.py | +18 -5 | float16 ↔ bfloat16 weight loading uses to() value conversion (only same-family dtypes such as fp8 variants keep view() bit reinterpretation) |
| atom/models/kimi_k25.py | +6 | Register the Eagle3 aux layer interface |
| atom/model_ops/attention_mha.py | +4 -1 | sliding_window=0 routed to the triton kernel |
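
The linear.py change in isolation: a self-contained illustration of why cross-family dtype loading must convert values with to() rather than reinterpret bits with view():

```python
import torch

w = torch.tensor([1.0, -2.5], dtype=torch.float16)

# view() reinterprets raw bits. float16 and bfloat16 have the same width but
# different encodings, so the result is numerically meaningless.
print(w.view(torch.bfloat16))  # garbage values

# to() converts values, which is what float16 <-> bfloat16 loading needs.
# Same-family pairs (e.g. fp8 variants) can keep the view() reinterpretation.
print(w.to(torch.bfloat16))    # tensor([ 1.0000, -2.5000], dtype=torch.bfloat16)
```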

5. Usage

CLI arguments

| Argument | Required | Description |
| --- | --- | --- |
| --method eagle3 | yes | Enable Eagle3 speculative decoding (mutually exclusive with mtp) |
| --draft-model <path> | yes | Path to the Eagle3 draft model, e.g. /data/models/kimi-k2.5-eagle3 |
| --num-speculative-tokens <int> | yes | Number of autoregressive draft steps per iteration (recommended: 3) |
| --eagle3-aux-layer-ids <str> | no | Comma-separated target layer indices for aux hidden states; defaults to (2, N//2, N-3) |
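
The default aux layer selection spelled out; assuming the Kimi K2.5 target has N = 61 decoder layers (an assumption consistent with the (2, 30, 58) reported in Section 6):

```python
def default_aux_layer_ids(num_layers: int) -> tuple[int, int, int]:
    # One shallow, one middle, one near-final layer, matching vLLM's convention.
    return (2, num_layers // 2, num_layers - 3)

assert default_aux_layer_ids(61) == (2, 30, 58)  # values used in Section 6
```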

OpenAI-compatible server

```bash
AITER_LOG_LEVEL=WARNING python -m atom.entrypoints.openai_server \
    --model /data/models/Kimi-K2.5-MXFP4 \
    --kv_cache_dtype fp8 \
    -tp 8 \
    --method eagle3 \
    --draft-model /data/models/kimi-k2.5-eagle3 \
    --num-speculative-tokens 3
```

Offline inference

```bash
python -m atom.examples.simple_inference \
    --model /data/models/Kimi-K2.5-MXFP4 \
    --kv_cache_dtype fp8 \
    -tp 8 \
    --method eagle3 \
    --draft-model /data/models/kimi-k2.5-eagle3 \
    --num-speculative-tokens 3
```

Verification

```bash
# Server startup: HTTP /health is not sufficient; confirm GPU memory is actually allocated
rocm-smi --showmemuse        # VRAM% > 0 means the model is loaded

# Run lm_eval to verify acceptance rate and accuracy
lm_eval --model local-completions \
    --model_args base_url=http://localhost:8000/v1/completions,model=Kimi-K2.5-MXFP4 \
    --tasks gsm8k_cot_zeroshot --num_fewshot 5 --batch_size 64
```

Acceptance rate and per-position distribution are emitted in the server log statistics line.

Fallback

  • Disable speculative decoding: drop --method, --draft-model, --num-speculative-tokens
  • Keep speculation but disable CUDAGraph: add --enforce-eager

6. Results

Test environment: Kimi-K2.5-MXFP4 + lightseekorg/kimi-k2.5-eagle3, AMD MI350 × 8 (ROCm 7.2), TP=8, num_speculative_tokens=3, aux_layer_ids=(2,30,58), lm_eval GSM8K 5-shot.

| Metric | ATOM |
| --- | --- |
| Acceptance rate | 65.9% |
| Avg tokens / forward | 2.98 |
| GSM8K accuracy | 93.78% |

Acceptance distribution (mtp_k=3): 0/1/2/3 accepted = 17.3% / 15.6% / 18.7% / 48.4%.
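
This distribution is consistent with the throughput figure: the expected number of accepted draft tokens per step is 0×0.173 + 1×0.156 + 2×0.187 + 3×0.484 ≈ 1.98, and adding the target model's own token per forward gives ≈ 2.98 tokens / forward, matching the table above.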

vLLM's acceptance rate under the same setup is 72.9%. The remaining ~7-point gap is likely attributable to bf16 numerical differences between the ROCm and CUDA attention kernels and to MXFP4 dequantization implementation differences, both orthogonal to the scope of this PR.

7. Known limitations

  1. Draft only supports standard MHA (limited by the upstream checkpoint). If the draft config has a non-empty kv_lora_rank, startup raises NotImplementedError.
  2. --eagle3-aux-layer-ids is only effective when --method eagle3.
  3. The draft model itself still runs eager (only the target uses CUDAGraph).

Add Eagle3 speculative decoding support for Kimi K2.5 (MLA target):
- New Eagle3LlamaModel (1-layer Llama draft, standard MHA) in
  atom/models/eagle3_llama.py
- DraftKVCache abstraction for draft models whose attention format
  differs from target; independent block_size/num_blocks calculation
- Aux hidden state collection from configurable target model layers,
  passed through CUDAGraph capture/replay via graph_aux_hidden dict
- Eagle3LlamaModel.forward returns (post_norm, pre_norm); propose loop
  feeds pre_norm to next draft step to avoid double RMSNorm
- Default aux layer IDs (2, N//2, N-3), aligned with vLLM
- Float16 ↔ bfloat16 weight loading uses to() value conversion
  instead of view() bit reinterpretation
- CLI args: --method eagle3, --draft-model, --eagle3-aux-layer-ids
- Validation: aux-layer-ids requires --method eagle3; MLA draft raises
  NotImplementedError

GSM8K 5-shot acceptance rate 65.9%, accuracy 93.78% on
Kimi-K2.5-MXFP4 + kimi-k2.5-eagle3, MI350x8, TP=8, mtp_k=3.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
