[Kimi] support Eagle3 speculative decoding for Kimi K2.5#631
Add Eagle3 speculative decoding support for Kimi K2.5 (MLA target):

- New `Eagle3LlamaModel` (1-layer Llama draft, standard MHA) in `atom/models/eagle3_llama.py`
- `DraftKVCache` abstraction for draft models whose attention format differs from the target; independent block_size/num_blocks calculation
- Aux hidden state collection from configurable target model layers, passed through CUDAGraph capture/replay via `graph_aux_hidden` dict
- `Eagle3LlamaModel.forward` returns `(post_norm, pre_norm)`; propose loop feeds `pre_norm` to the next draft step to avoid double RMSNorm
- Default aux layer IDs `(2, N//2, N-3)`, aligned with vLLM
- Float16 ↔ bfloat16 weight loading uses `to()` value conversion instead of `view()` bit reinterpretation
- CLI args: `--method eagle3`, `--draft-model`, `--eagle3-aux-layer-ids`
- Validation: aux-layer-ids requires `--method eagle3`; MLA draft raises `NotImplementedError`

GSM8K 5-shot acceptance rate 65.9%, accuracy 93.78% on Kimi-K2.5-MXFP4 + kimi-k2.5-eagle3, MI350x8, TP=8, mtp_k=3.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Eagle3 Speculative Decoding for Kimi K2.5
Closes ROCm/ATOM#553
1. Goal
Add Eagle3 speculative decoding support for Kimi K2.5 (MLA target) in ATOM:
The draft checkpoint used in this PR is `lightseekorg/kimi-k2.5-eagle3`, a 1-layer standard Llama decoder (MHA). The MLA-draft variant is not yet released upstream; the abstraction introduced here is format-agnostic and can be extended to an MLA draft when one becomes available.

End result: GSM8K 5-shot acceptance rate 65.9%, accuracy 93.78% on Kimi-K2.5-MXFP4.
2. Design
Eagle3 differs from the existing MTP path in ATOM along four key dimensions: the draft input projection `fc(hidden, aux_hidden_states)`, the draft forward's `tuple` return `(post_norm, pre_norm)`, the draft KV cache format (MHA vs the target's MLA), and aux hidden state collection from target layers.

Three pieces of design address these differences:
- `DraftKVCache` abstraction: encapsulates the draft model's independent `cache / scale / block_size / num_blocks`. Draft and target share the same memory budget pool, but block granularity is computed independently (`draft_num_blocks = num_kvcache_blocks × target_block_size / draft_block_size`) so total token capacity stays consistent. The abstraction is not bound to Eagle3: any future draft whose attention format diverges from the target can reuse it.
- Draft attention uses `block_size=16/1024` aligned with the MHA backend, so it goes through `paged_attention_triton` → `pa_decode_gluon` directly. Attention metadata (`block_tables` / `context_lens` / `slot_mapping`) is reused from the scheduler-populated `forward_vars` once before the propose loop, then incrementally updated with `context_lens += 1` per draft step.
- `graph_aux_hidden: dict[(bs, max_q_len), list[Tensor]]` stores references to the aux tensors materialized inside the captured graph. After replay, slices are taken to the actual `num_tokens` for the drafter to consume. Type dispatch uses a `use_aux_hidden_state_outputs` flag set at init time, avoiding runtime `isinstance` checks that would pollute graph capture.

3. Comparison with the existing MTP workflow
| Aspect | Eagle3 |
| --- | --- |
| Draft input | `(hidden, aux_hidden_states)` |
| CUDAGraph | `graph_aux_hidden` saves aux references; sliced to `num_tokens` after replay |
| KV cache | `DraftKVCache` allocated independently with its own block granularity |
| Attention metadata | reused from `forward_vars` (`block_tables` / `context_lens` / `slot_mapping`) |
| Draft forward return | `(post_norm, pre_norm)`; logits use `post_norm`, next step uses `pre_norm` |

Flow diagram
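The block-granularity relation described above can be sketched as a small helper. This is a hypothetical sketch whose names mirror the PR description, not the actual ATOM code:

```python
def draft_num_blocks(num_kvcache_blocks: int,
                     target_block_size: int,
                     draft_block_size: int) -> int:
    """Derive the draft cache's block count from equal total token capacity.

    Draft and target share one memory budget pool, so the number of draft
    blocks is whatever keeps num_blocks * block_size constant across the two
    cache formats.
    """
    total_tokens = num_kvcache_blocks * target_block_size
    # Capacity must divide evenly into draft-sized blocks.
    assert total_tokens % draft_block_size == 0
    return total_tokens // draft_block_size
```

For example, 4096 target blocks of 16 tokens each correspond to 64 draft blocks of 1024 tokens, preserving the same total token capacity.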
4. Changes
| File | Change |
| --- | --- |
| `atom/model_engine/model_runner.py` | `DraftKVCache` + `_slice_kv_cache_block()`; Eagle3 KV cache allocation/binding; CUDAGraph capture handles tuple return + `graph_aux_hidden`; aux layer registration |
| `atom/models/eagle3_llama.py` | `Eagle3LlamaModel`, `Eagle3LlamaDecoderLayer`; `forward` returns `(post_norm, pre_norm)` |
| `atom/spec_decode/eagle.py` | `forward_vars` before propose loop; unpack `(post_norm, pre_norm)`; feed `pre_norm` to next draft step |
| `atom/model_engine/arg_utils.py` | `--method eagle3`, `--draft-model`, `--eagle3-aux-layer-ids` |
| `atom/config.py` | `SpeculativeConfig` parses Eagle3 fields; argument validation; default `use_aux_hidden_state=True` when `eagle_config` absent |
| `atom/models/deepseek_v2.py` | default aux layer IDs `(2, N//2, N-3)`, aligned with vLLM |
| `atom/model_ops/linear.py` | `to()` value conversion (only same-family dtypes such as fp8 variants keep `view()` bit reinterpretation) |
| `atom/models/kimi_k25.py` | |
| `atom/model_ops/attention_mha.py` | `sliding_window=0` routed to triton kernel |
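The default aux-layer rule `(2, N//2, N-3)` noted for `atom/models/deepseek_v2.py` can be sketched as below. The helper name is hypothetical; only the rule itself comes from the PR:

```python
def default_aux_layer_ids(num_hidden_layers: int) -> tuple[int, int, int]:
    # Pick an early, a middle, and a late target layer for aux hidden
    # state collection: the (2, N//2, N-3) rule, aligned with vLLM.
    n = num_hidden_layers
    return (2, n // 2, n - 3)
```

For a 61-layer target this yields `(2, 30, 58)`, matching the `aux_layer_ids` used in the Results section.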
CLI arguments
| Argument | Notes |
| --- | --- |
| `--method eagle3` | (default: `mtp`) |
| `--draft-model <path>` | e.g. `/data/models/kimi-k2.5-eagle3` |
| `--num-speculative-tokens <int>` | |
| `--eagle3-aux-layer-ids <str>` | default `(2, N//2, N-3)` |

OpenAI-compatible server
```bash
AITER_LOG_LEVEL=WARNING python -m atom.entrypoints.openai_server \
  --model /data/models/Kimi-K2.5-MXFP4 \
  --kv_cache_dtype fp8 \
  -tp 8 \
  --method eagle3 \
  --draft-model /data/models/kimi-k2.5-eagle3 \
  --num-speculative-tokens 3
```

Offline inference
```bash
python -m atom.examples.simple_inference \
  --model /data/models/Kimi-K2.5-MXFP4 \
  --kv_cache_dtype fp8 \
  -tp 8 \
  --method eagle3 \
  --draft-model /data/models/kimi-k2.5-eagle3 \
  --num-speculative-tokens 3
```

Verification
Acceptance rate and per-position distribution are emitted in the server log statistics line.
Fallback
Omit `--method`, `--draft-model`, and `--num-speculative-tokens` to fall back to non-speculative decoding; `--enforce-eager` disables CUDAGraph capture.

6. Results
Test environment: Kimi-K2.5-MXFP4 + `lightseekorg/kimi-k2.5-eagle3`, AMD MI350 × 8 (ROCm 7.2), TP=8, `num_speculative_tokens=3`, `aux_layer_ids=(2,30,58)`, lm_eval GSM8K 5-shot.

Acceptance distribution (`mtp_k=3`): 0/1/2/3 accepted = 17.3% / 15.6% / 18.7% / 48.4%.

vLLM's acceptance rate under the same setup is 72.9%. The remaining ~7% gap is likely attributable to bf16 numerical differences between ROCm and CUDA attention kernels and to MXFP4 dequant implementation differences, both orthogonal to the scope of this PR.
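As a sanity check (a sketch added here, not part of the PR), the per-position acceptance rate implied by the reported distribution can be recomputed: the expected number of accepted draft tokens per step, divided by `mtp_k`, should land near the reported 65.9% up to rounding of the distribution percentages.

```python
# Recompute the implied acceptance rate from the reported 0/1/2/3-accepted
# shares (mtp_k = 3). Values are the rounded percentages from the log.
dist = {0: 0.173, 1: 0.156, 2: 0.187, 3: 0.484}
mean_accepted = sum(k * p for k, p in dist.items())  # expected tokens/step
acceptance_rate = mean_accepted / 3
print(f"{mean_accepted:.3f} accepted tokens/step, {acceptance_rate:.1%}")
```

This gives roughly 1.98 accepted tokens per step (about 66%), consistent with the reported 65.9% given the rounding in the published distribution.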
7. Known limitations
- MLA draft models (configs with `kv_lora_rank`) are not supported; startup raises `NotImplementedError`.
- `--eagle3-aux-layer-ids` is only effective when `--method eagle3`.