feat: decouple GDN recurrent state from KV cache block pool#602
Merged
Conversation
Previously, GDN (Gated DeltaNet) recurrent state was packed into every KV cache block (23.41 MB/block, 96% of which was GDN state), causing 98% memory waste (148 GB allocated, only 2.9 GB used). This change decouples the recurrent state into a per-request slot pool that competes dynamically with KV cache blocks in a unified pool:

- `block_bytes` now contains only KV cache + scale (not GDN state)
- GDN state is allocated as a separate tensor sized to `max_num_seqs`
- `BlockManager` tracks mamba memory via equiv-block accounting
- `Sequence.mamba_block_table` is replaced with `mamba_state_slot` (int)
- State indices in `gdn_attn` use a slot_group mapping instead of `block_id`

Results on Qwen3.5-397B-A17B-FP8 tp=4:

- Pool blocks: 6,906 → 565,016 (82x increase)
- Throughput: 4,657 → 7,415 tok/s (+59%)
- GSM8K accuracy: 0.8711 (no regression)
- DeepSeek-R1 (pure MLA): no regression (0.9575 accuracy)
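To make the equiv-block accounting concrete, here is a minimal, hypothetical sketch of charging a per-request state slot against the same budget as KV cache blocks. The class and method names are illustrative, not the PR's actual `BlockManager` API:

```python
from collections import deque

class UnifiedPoolSketch:
    """Toy model of the unified pool: KV blocks and per-request GDN state
    slots draw on one memory budget via 'equivalent block' accounting."""

    def __init__(self, num_blocks: int, max_num_seqs: int, equiv_blocks_per_slot: int):
        self.free_blocks = deque(range(num_blocks))   # KV cache blocks (no GDN state inside)
        self.free_slots = deque(range(max_num_seqs))  # per-request recurrent-state slots
        self.equiv = equiv_blocks_per_slot            # blocks one slot is "worth"

    def can_allocate(self, num_kv_blocks: int, needs_slot: bool) -> bool:
        # A hybrid-model request must fit its KV blocks *and* one slot's
        # worth of equivalent blocks inside the same budget.
        charge = num_kv_blocks + (self.equiv if needs_slot else 0)
        return charge <= len(self.free_blocks) and (not needs_slot or bool(self.free_slots))

    def allocate(self, num_kv_blocks: int, needs_slot: bool):
        assert self.can_allocate(num_kv_blocks, needs_slot)
        blocks = [self.free_blocks.popleft() for _ in range(num_kv_blocks)]
        slot = None
        if needs_slot:
            slot = self.free_slots.popleft()
            for _ in range(self.equiv):  # charge the slot's bytes to the block budget
                self.free_blocks.popleft()
        return blocks, slot

pool = UnifiedPoolSketch(num_blocks=1000, max_num_seqs=8, equiv_blocks_per_slot=50)
blocks, slot = pool.allocate(num_kv_blocks=16, needs_slot=True)
assert slot == 0 and len(pool.free_blocks) == 1000 - 16 - 50
```

The real `BlockManager` additionally covers deallocation and preemption paths (see the per-file table below).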
Contributor
Pull request overview
Decouples GDN (Gated DeltaNet) recurrent state from KV-cache block storage by introducing per-request state slot allocation with “equivalent block” accounting, and extends speculative decoding (MTP) support and observability (CI + dashboard) for Qwen3.5.
Changes:
- Introduce per-request `mamba_state_slot` allocation and unified pool accounting (`mamba_equiv_per_req`, `num_mamba_groups`) across `BlockManager`, `Sequence`, `Scheduler`, and GDN attention metadata; a slot-indexing sketch follows this list.
- Add/extend MTP draft model support (notably Qwen3.5 MTP) and improve weight loading + proposer metadata handling for MLA vs MHA.
- Expand CI/dashboard reporting (model display names, MTP acceptance metrics) and add unit tests + documentation updates.
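As a rough illustration of the slot-based indexing, the GDN state tensor is indexed by a per-request slot rather than by KV block id. The helper name and signature below are assumptions for illustration, not the repo's exact API:

```python
import torch

def gather_recurrent_state(gdn_state: torch.Tensor, mamba_state_slots: list[int]) -> torch.Tensor:
    # Recurrent-state rows are looked up by per-request slot index,
    # not by KV cache block id.
    idx = torch.tensor(mamba_state_slots, dtype=torch.long, device=gdn_state.device)
    return gdn_state.index_select(0, idx)

# Toy pool sized to max_num_seqs=4; a batch of two requests holding slots 3 and 0.
state_pool = torch.zeros(4, 2, 8)
batch_state = gather_recurrent_state(state_pool, [3, 0])
assert batch_state.shape == (2, 2, 8)
```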
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `tests/test_mamba_state_decoupling.py` | New unit tests for mamba slot allocation/accounting, mapping, and scheduler integration. |
| `recipes/Qwen3.5.md` | New Qwen3.5 usage/benchmarking/accuracy guide. |
| `docs/scheduling_kv_cache_guide.md` | Documents mamba slot pool + updated preempt, can_allocate, can_append semantics. |
| `docs/model_support_guide.md` | Adds Qwen3.5 (+ MTP) registry/documentation entries and expands MTP section. |
| `docs/configuration_guide.md` | Documents auto-derived mamba/MTP config fields and table-driven MTP config mapping. |
| `docs/architecture_guide.md` | Adds mamba_state_slot to Sequence fields and expands speculative decoding overview. |
| `atom/spec_decode/eagle.py` | Adds Qwen3.5 MTP support, improves sharing logic using loaded-weight records, and branches metadata updates for MLA vs MHA. |
| `atom/models/qwen3_next_mtp.py` | Moves MTP weight filtering/remap to a method + uses weights_mapping. |
| `atom/models/qwen3_5_mtp.py` | New Qwen3.5 MTP draft model implementation. |
| `atom/models/deepseek_mtp.py` | Adds remap_mtp_weight_name method for loader dispatch. |
| `atom/model_ops/mamba_ops/causal_conv1d.py` | Fixes VARLEN stride selection for output token indexing. |
| `atom/model_ops/attentions/gdn_attn.py` | Switches state indexing to slot-based indices and adds prepare_mtp_decode() for GDN hybrid models. |
| `atom/model_loader/loader.py` | Generalizes MTP remap dispatch via model.remap_mtp_weight_name and always returns loaded_weights_record. |
| `atom/model_engine/sequence.py` | Replaces mamba_block_table/num_mamba_blocks with mamba_state_slot. |
| `atom/model_engine/scheduler.py` | Collects mamba_state_slots in ScheduledBatch and strips speculative placeholders on preempt. |
| `atom/model_engine/model_runner.py` | Removes per-block GDN state bytes from KV blocks, introduces per-slot sizing, adjusts MTP init + KV binding + state tensor sizing. |
| `atom/model_engine/engine_core.py` | Adapts to get_num_blocks() returning a dict and plumbs mamba accounting fields into config. |
| `atom/model_engine/block_manager.py` | Adds mamba slot pool + equivalent-block accounting, updates can_allocate/allocate/deallocate. |
| `atom/config.py` | Implements table-driven MTP config override (_MTP_TYPE_MAP, _MTP_CONFIG) and supports multimodal text_config. |
| `CLAUDE.md` | Updates model reuse notes to include Qwen3.5 MTP. |
| `.github/workflows/atom-benchmark.yaml` | Uses shared transform script and injects model display name into benchmark payload + publishes models map asset. |
| `.github/scripts/plugin_benchmark_to_dashboard.py` | Improves fallback model-name derivation. |
| `.github/scripts/atom_test.sh` | Extracts MTP acceptance rate/toks-per-forward from server logs into CI metadata. |
| `.github/scripts/accuracy_to_dashboard.py` | Publishes MTP acceptance and avg toks/fwd metrics to dashboard entries. |
| `.github/dashboard/index.html` | Adds model-name normalization via mapping, parses MTP metrics, and renders MTP acceptance trend in accuracy views. |
| `.github/benchmark/models_accuracy.json` | Adds Qwen3.5 FP8 MTP3 accuracy entry. |
| `.github/benchmark/models.json` | Adds Qwen3.5 FP8 MTP3 perf entry and normalizes some DeepSeek display naming. |
| `.claude/commands/debug-guide.md` | Updates debugging guidance for MTP + compilation modes and mamba slot concepts. |
| `.claude/commands/benchmark-guide.md` | Adds critical rules for benchmarking and MTP usage notes. |
| `.claude/commands/add-model.md` | Updates model-add guide for multimodal wrappers + MTP support patterns (incl. Qwen3.5). |
Comments suppressed due to low confidence (1)
atom/model_engine/model_runner.py:589
`torch.set_default_device(None)` is likely invalid (PyTorch expects a device string/`torch.device`) and can raise a `TypeError` at startup. Use an explicit device such as `torch.set_default_device("cpu")`, or restore the prior default device via a saved value, after finishing drafter construction/loading (a hedged alternative is sketched after the snippet below).
```python
torch.set_default_device(self.device)
with set_model_tag("drafter"):
    self.drafter = EagleProposer(self.config, self.device, self)
self.rejection_sampler = RejectionSampler()
torch.set_default_device(None)
logger.info("Loading drafter model...")
```
valarLip added a commit that referenced this pull request on Apr 29, 2026:
* refactor: delegate ATOM KV cache subsystem to attention builders

Generalize the GDN per-request state decoupling (#602) into a complete model-agnostic KV abstraction owned by the AttentionMetadataBuilder hierarchy. ModelRunner is now blind to attention type — it walks modules and dispatches; per-attention-type tensor layouts (MLA 576-dim packed, GDN-hybrid full-attn-only rows, MiMo-V2 per-module deferred, V3.2 indexer cache, GDN per-req mamba state) all live next to their respective builder. ModelRunner net: -526 LOC.

The if/elif chains over use_mla / is_qwen_next / is_mimo_v2 / is_deepseek_v32 in _compute_block_bytes, allocate_kv_cache, and the binding loop are all gone. Future stateful attentions (DeepseekV4 ring buffer + compressor state) plug in by subclassing AttentionMetadataBuilder without touching scheduler / block_manager / ModelRunner.

New AttentionMetadataBuilder hooks (defaults are no-ops):
- compute_per_req_cache_bytes() / slots_per_req(): bytes/slot for the per-request state pool
- allocate_per_req_cache(num_slots): dict of named per-request state tensors
- compute_block_bytes(): per-block bytes for the KV pool budget
- allocate_kv_cache_tensors(num_kv_heads, num_draft_layers): dict of named primary KV cache tensors (kv_cache, kv_scale, index_cache, aligned_index_dim, _kv_layer_cache_store)
- build_kv_cache_tensor(layer_id, module): vLLM-style KVCacheTensor for one module, or None if foreign type; owns module setattr (k_cache/v_cache/k_scale/v_scale/kv_cache)

Builder overrides:
- AiterAttentionMetadataBuilder: split-K/V MHA + MiMo-V2 per-module
- AiterMLAMetadataBuilder: 576-dim MLA + V3.2 indexer
- GDNAttentionMetadataBuilder: hybrid full-attn rows + GDN mamba slot pool; chains super() for MHA modules in hybrid models. Absorbs the formerly-runner-owned gated_delta_net_state_shape/dtypes helpers and the side-effect init of full_attention_interval / num_full_attn / num_gdn_attn_state.

Naming distinguishes group (per-request unit) from slot (raw tensor index). One group occupies `slots_per_req()` contiguous slots in the underlying tensor:
- Sequence.mamba_state_slot -> .per_req_cache_group
- seq.mamba_enabled -> .has_per_req_cache
- batch.mamba_state_slots -> .per_req_cache_groups
- BlockManager.mamba_* -> .per_req_cache_* (free pool, accounting)
- config.mamba_equiv_per_req -> .per_req_cache_equiv_blocks
- config.num_mamba_groups -> .num_per_req_cache_groups
- ModelRunner.max_mamba_slots -> .max_per_req_cache_slots (tensor dim)

Removed (moved to builders): ModelRunner._compute_mamba_per_slot_bytes, ModelRunner.gated_delta_net_state_shape / _dtypes.

Sanity check: ModelRunner.__init__ now asserts that any builder returning compute_per_req_cache_bytes() > 0 has its model_type registered in InputOutputProcessor._per_req_cache_model_types(), catching the silent-corruption misconfiguration where a stateful attention is added but Sequence construction never gets the has_per_req_cache=True flag.

Verified:
- tests/test_per_req_cache_decoupling.py: 24/24 pass
- core suite (block_manager, sequence, scheduler, request, io_processor_fanout, prefix_cache_accuracy): 118/118 pass
- Qwen3.5-397B-A17B-FP8 tp=4 simple_inference: 4-prompt completion quality unchanged
- Qwen3.5-397B-A17B-FP8 tp=4 GSM8K (5-shot, 64 concurrent): flexible-extract = 0.8757 +/- 0.0091 (baseline 0.8711 from #602); strict-match = 0.8605 +/- 0.0095

* style: black format block_manager.py
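A hedged Python sketch of the builder hook surface follows; the hook names and documented no-op defaults come from the commit message above, but the signatures are paraphrased rather than copied from the repository:

```python
import torch

class AttentionMetadataBuilderSketch:
    """Sketch of the hook surface described in the commit message."""

    # Per-request state pool (e.g. GDN mamba state). Defaults mean "no state".
    def compute_per_req_cache_bytes(self) -> int:
        return 0  # bytes per slot; 0 disables the per-request pool

    def slots_per_req(self) -> int:
        return 1  # one group spans this many contiguous slots

    def allocate_per_req_cache(self, num_slots: int) -> dict[str, torch.Tensor]:
        return {}  # named per-request state tensors

    # Primary KV pool. Defaults mean "nothing to contribute".
    def compute_block_bytes(self) -> int:
        return 0  # per-block bytes counted against the KV pool budget

    def allocate_kv_cache_tensors(self, num_kv_heads: int, num_draft_layers: int) -> dict[str, torch.Tensor]:
        return {}  # named primary KV cache tensors (kv_cache, kv_scale, ...)

    def build_kv_cache_tensor(self, layer_id: int, module):
        return None  # None = module belongs to a different builder
```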
valarLip added a commit that referenced this pull request on Apr 30, 2026 (same commit message as above).
Summary
GDN state tensor is sized to `max_num_seqs` instead of `num_physical_kvcache_blocks`. Results on Qwen3.5-397B-A17B-FP8 tp=4 (1024/1024, c=128):
Regression test (DeepSeek-R1-0528 tp=8, pure MLA): No impact (0.9575 accuracy, unchanged throughput)
Test plan