
feat: decouple GDN recurrent state from KV cache block pool #602

Merged: valarLip merged 1 commit into main from feat/gdn-state-decoupling on Apr 18, 2026
Conversation

@valarLip
Collaborator

Summary

  • Decouple GDN (Gated DeltaNet) recurrent state from the KV cache block pool into per-request slot allocation
  • GDN state was previously packed into every KV block (23.41 MB/block, 96% of which was GDN state), wasting 98% of the allocated memory
  • Unified pool with dynamic KV + mamba competition via equiv-block accounting (sketched below)
  • Mamba state tensor sized to max_num_seqs instead of num_physical_kvcache_blocks
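
To make the equiv-block accounting concrete, here is a minimal sketch of how a block manager could charge a request's recurrent-state slot against the same budget as KV blocks; mamba_equiv_per_req and can_allocate are names from this PR, while the class around them is hypothetical, not the actual BlockManager.

```python
# Illustrative sketch of equiv-block accounting (hypothetical class, not
# the actual BlockManager). A GDN request consumes one mamba slot, charged
# against the shared budget by converting its byte cost into whole blocks.
class UnifiedPoolSketch:
    def __init__(self, num_blocks: int, block_bytes: int, mamba_slot_bytes: int):
        self.free_blocks = num_blocks
        # ceil-divide so a slot is never under-charged
        self.mamba_equiv_per_req = -(-mamba_slot_bytes // block_bytes)

    def can_allocate(self, num_kv_blocks: int, needs_mamba_slot: bool) -> bool:
        cost = num_kv_blocks
        if needs_mamba_slot:
            cost += self.mamba_equiv_per_req  # mamba state competes with KV blocks
        return self.free_blocks >= cost
```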

Results on Qwen3.5-397B-A17B-FP8 (tp=4, 1024-token input / 1024-token output, concurrency 128):

| Metric | Before | After | Change |
| --- | --- | --- | --- |
| Pool blocks | 6,906 | 565,016 | 82x |
| Throughput | 4,657 tok/s | 7,415 tok/s | +59% |
| GSM8K accuracy | N/A | 0.8711 | -- |
| MTP3 throughput | -- | 11,089 tok/s | -- |

Regression test (DeepSeek-R1-0528 tp=8, pure MLA): No impact (0.9575 accuracy, unchanged throughput)

Test plan

  • 24 new unit tests covering BlockManager mamba slots, Sequence fields, state index mapping, Scheduler integration
  • GSM8K accuracy: Qwen3-Next (0.8605), Qwen3.5 FP8 (0.8711), Qwen3.5 MTP3 (0.8643)
  • Performance benchmark: Qwen3.5 FP8 (+59%), Qwen3.5 MTP3 (11,089 tok/s)
  • Regression: DeepSeek-R1-0528 accuracy (0.9575) and performance unaffected
  • All Triton/CUDA kernels verified compatible: they consume state indices as opaque dim-0 offsets, so they are agnostic to whether the values are block IDs or slot IDs (illustrated below)
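
The value-agnostic point is easy to see with a toy gather (a hypothetical snippet, not code from gdn_attn.py): the kernel-side access pattern is identical whether the dim-0 indices are block IDs into a packed pool or slot IDs into a dedicated tensor.

```python
import torch

# Toy illustration: the gather below is the same whether `state_indices`
# holds block IDs (old layout) or per-request slot IDs (new layout);
# kernels only use the values as dim-0 offsets into the tensor they get.
state = torch.randn(8, 4, 4)             # [num_slots, ...] recurrent state pool
state_indices = torch.tensor([2, 5, 0])  # one entry per scheduled request
batch_state = state[state_indices]       # shape [3, 4, 4], value-agnostic gather
```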

Copilot AI review requested due to automatic review settings April 18, 2026 15:51
Previously, GDN (Gated DeltaNet) recurrent state was packed into every
KV cache block (23.41 MB/block, 96% being GDN state), causing 98%
memory waste (148 GB allocated, only 2.9 GB used).

This change decouples recurrent state into a per-request slot pool that
dynamically competes with KV cache blocks in a unified pool:

- block_bytes now contains only KV cache + scale (not GDN state)
- GDN state is allocated as a separate tensor sized to max_num_seqs
- BlockManager tracks mamba memory via equiv-block accounting
- Sequence.mamba_block_table replaced with mamba_state_slot (int)
- State indices in gdn_attn use slot_group mapping instead of block_id

Results on Qwen3.5-397B-A17B-FP8 tp=4:
- Pool blocks: 6,906 → 565,016 (82x increase)
- Throughput: 4,657 → 7,415 tok/s (+59%)
- GSM8K accuracy: 0.8711 (no regression)
- DeepSeek-R1 (pure MLA): no regression (0.9575 accuracy)
valarLip force-pushed the feat/gdn-state-decoupling branch from ba2c4a0 to 9961beb on April 18, 2026 15:54
Copilot AI (Contributor) left a comment


Pull request overview

Decouples GDN (Gated DeltaNet) recurrent state from KV-cache block storage by introducing per-request state slot allocation with “equivalent block” accounting, and extends speculative decoding (MTP) support and observability (CI + dashboard) for Qwen3.5.

Changes:

  • Introduce per-request mamba_state_slot allocation and unified pool accounting (mamba_equiv_per_req, num_mamba_groups) across BlockManager, Sequence, Scheduler, and GDN attention metadata.
  • Add/extend MTP draft model support (notably Qwen3.5 MTP) and improve weight-loading + proposer metadata handling for MLA vs MHA.
  • Expand CI/dashboard reporting (model display names, MTP acceptance metrics) and add unit tests + documentation updates.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated no comments.

Summary per file:

| File | Description |
| --- | --- |
| tests/test_mamba_state_decoupling.py | New unit tests for mamba slot allocation/accounting, mapping, and scheduler integration. |
| recipes/Qwen3.5.md | New Qwen3.5 usage/benchmarking/accuracy guide. |
| docs/scheduling_kv_cache_guide.md | Documents mamba slot pool + updated preempt, can_allocate, can_append semantics. |
| docs/model_support_guide.md | Adds Qwen3.5 (+ MTP) registry/documentation entries and expands MTP section. |
| docs/configuration_guide.md | Documents auto-derived mamba/MTP config fields and table-driven MTP config mapping. |
| docs/architecture_guide.md | Adds mamba_state_slot to Sequence fields and expands speculative decoding overview. |
| atom/spec_decode/eagle.py | Adds Qwen3.5 MTP support, improves sharing logic using loaded-weight records, and branches metadata updates for MLA vs MHA. |
| atom/models/qwen3_next_mtp.py | Moves MTP weight filtering/remap to a method + uses weights_mapping. |
| atom/models/qwen3_5_mtp.py | New Qwen3.5 MTP draft model implementation. |
| atom/models/deepseek_mtp.py | Adds remap_mtp_weight_name method for loader dispatch. |
| atom/model_ops/mamba_ops/causal_conv1d.py | Fixes VARLEN stride selection for output token indexing. |
| atom/model_ops/attentions/gdn_attn.py | Switches state indexing to slot-based indices and adds prepare_mtp_decode() for GDN hybrid models. |
| atom/model_loader/loader.py | Generalizes MTP remap dispatch via model.remap_mtp_weight_name and always returns loaded_weights_record. |
| atom/model_engine/sequence.py | Replaces mamba_block_table/num_mamba_blocks with mamba_state_slot. |
| atom/model_engine/scheduler.py | Collects mamba_state_slots in ScheduledBatch and strips speculative placeholders on preempt. |
| atom/model_engine/model_runner.py | Removes per-block GDN state bytes from KV blocks, introduces per-slot sizing, adjusts MTP init + KV binding + state tensor sizing. |
| atom/model_engine/engine_core.py | Adapts to get_num_blocks() returning a dict and plumbs mamba accounting fields into config. |
| atom/model_engine/block_manager.py | Adds mamba slot pool + equivalent-block accounting, updates can_allocate/allocate/deallocate. |
| atom/config.py | Implements table-driven MTP config override (_MTP_TYPE_MAP, _MTP_CONFIG) and supports multimodal text_config. |
| CLAUDE.md | Updates model reuse notes to include Qwen3.5 MTP. |
| .github/workflows/atom-benchmark.yaml | Uses shared transform script and injects model display name into benchmark payload + publishes models map asset. |
| .github/scripts/plugin_benchmark_to_dashboard.py | Improves fallback model-name derivation. |
| .github/scripts/atom_test.sh | Extracts MTP acceptance rate/toks-per-forward from server logs into CI metadata. |
| .github/scripts/accuracy_to_dashboard.py | Publishes MTP acceptance and avg toks/fwd metrics to dashboard entries. |
| .github/dashboard/index.html | Adds model-name normalization via mapping, parses MTP metrics, and renders MTP acceptance trend in accuracy views. |
| .github/benchmark/models_accuracy.json | Adds Qwen3.5 FP8 MTP3 accuracy entry. |
| .github/benchmark/models.json | Adds Qwen3.5 FP8 MTP3 perf entry and normalizes some DeepSeek display naming. |
| .claude/commands/debug-guide.md | Updates debugging guidance for MTP + compilation modes and mamba slot concepts. |
| .claude/commands/benchmark-guide.md | Adds critical rules for benchmarking and MTP usage notes. |
| .claude/commands/add-model.md | Updates model-add guide for multimodal wrappers + MTP support patterns (incl. Qwen3.5). |
Comments suppressed due to low confidence (1)

atom/model_engine/model_runner.py:589

  • torch.set_default_device(None) is likely invalid (PyTorch expects a device string/torch.device) and can raise a TypeError at startup. Use an explicit device such as torch.set_default_device("cpu"), or restore the prior default device via a saved value, after finishing drafter construction/loading.
```python
            torch.set_default_device(self.device)
            with set_model_tag("drafter"):
                self.drafter = EagleProposer(self.config, self.device, self)
            self.rejection_sampler = RejectionSampler()
            torch.set_default_device(None)
            logger.info("Loading drafter model...")
```
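
For reference, current PyTorch releases do accept None here (it resets the default device to its initial state), but a scoped alternative sidesteps the question entirely. A minimal sketch, with build_drafter standing in for the actual drafter construction:

```python
import torch

def build_drafter():
    return torch.nn.Linear(8, 8)  # hypothetical stand-in for EagleProposer

# torch.device works as a context manager in PyTorch 2.0+: tensors created
# inside default to the given device, and the previous default is restored
# automatically on exit, with no explicit reset call.
with torch.device("cpu"):  # "cuda" in the real setting
    drafter = build_drafter()
```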


valarLip merged commit 3138d9a into main on Apr 18, 2026 (26 of 28 checks passed)
valarLip deleted the feat/gdn-state-decoupling branch on April 18, 2026 16:30
valarLip added a commit that referenced this pull request Apr 29, 2026
* refactor: delegate ATOM KV cache subsystem to attention builders

Generalize the GDN per-request state decoupling (#602) into a complete
model-agnostic KV abstraction owned by the AttentionMetadataBuilder
hierarchy. ModelRunner is now blind to attention type — it walks modules
and dispatches; per-attention-type tensor layouts (MLA 576-dim packed,
GDN-hybrid full-attn-only rows, MiMo-V2 per-module deferred, V3.2
indexer cache, GDN per-req mamba state) all live next to their
respective builder.

ModelRunner net: -526 LOC. The if/elif chains over use_mla /
is_qwen_next / is_mimo_v2 / is_deepseek_v32 in _compute_block_bytes,
allocate_kv_cache, and the binding loop are all gone. Future stateful
attentions (DeepseekV4 ring buffer + compressor state) plug in by
subclassing AttentionMetadataBuilder without touching scheduler /
block_manager / ModelRunner.

New AttentionMetadataBuilder hooks (defaults are no-ops; a sketch follows the list):
  - compute_per_req_cache_bytes() / slots_per_req()
      bytes/slot for the per-request state pool
  - allocate_per_req_cache(num_slots)
      dict of named per-request state tensors
  - compute_block_bytes()
      per-block bytes for the KV pool budget
  - allocate_kv_cache_tensors(num_kv_heads, num_draft_layers)
      dict of named primary KV cache tensors (kv_cache, kv_scale,
      index_cache, aligned_index_dim, _kv_layer_cache_store)
  - build_kv_cache_tensor(layer_id, module)
      vLLM-style KVCacheTensor for one module, or None if foreign type;
      owns module setattr (k_cache/v_cache/k_scale/v_scale/kv_cache)
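
A minimal sketch of what this hook surface could look like on the base class; the method names and return semantics follow the commit message, while the bodies, type hints, and defaults are assumptions:

```python
# Sketch of the hook surface listed above (not the actual base class).
# Defaults are no-ops so stateless attention builders override nothing.
class AttentionMetadataBuilder:
    def compute_per_req_cache_bytes(self) -> int:
        return 0  # bytes per slot of per-request state (0 = no state pool)

    def slots_per_req(self) -> int:
        return 1  # contiguous raw slots owned by one request group

    def allocate_per_req_cache(self, num_slots: int) -> dict:
        return {}  # name -> per-request state tensor

    def compute_block_bytes(self) -> int:
        return 0  # per-block bytes counted against the KV pool budget

    def allocate_kv_cache_tensors(self, num_kv_heads: int, num_draft_layers: int) -> dict:
        return {}  # name -> primary KV cache tensor (kv_cache, kv_scale, ...)

    def build_kv_cache_tensor(self, layer_id: int, module):
        return None  # KVCacheTensor for this module, or None if foreign type
```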

Builder overrides:
  - AiterAttentionMetadataBuilder: split-K/V MHA + MiMo-V2 per-module
  - AiterMLAMetadataBuilder: 576-dim MLA + V3.2 indexer
  - GDNAttentionMetadataBuilder: hybrid full-attn rows + GDN mamba slot
    pool; chains super() for MHA modules in hybrid models. Absorbs the
    formerly-runner-owned gated_delta_net_state_shape/dtypes helpers
    and the side-effect init of full_attention_interval / num_full_attn
    / num_gdn_attn_state.

Naming distinguishes group (per-request unit) from slot (raw tensor
index). One group occupies `slots_per_req()` contiguous slots in the
underlying tensor (worked example after the mapping):
  Sequence.mamba_state_slot     -> .per_req_cache_group
  seq.mamba_enabled             -> .has_per_req_cache
  batch.mamba_state_slots       -> .per_req_cache_groups
  BlockManager.mamba_*          -> .per_req_cache_*  (free pool, accounting)
  config.mamba_equiv_per_req    -> .per_req_cache_equiv_blocks
  config.num_mamba_groups       -> .num_per_req_cache_groups
  ModelRunner.max_mamba_slots   -> .max_per_req_cache_slots  (tensor dim)
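
As a worked example of the group-to-slot mapping (a hypothetical helper, not part of the diff):

```python
# Hypothetical helper illustrating the mapping: group g owns
# slots_per_req contiguous rows of the underlying state tensor.
def group_to_slots(group: int, slots_per_req: int) -> range:
    start = group * slots_per_req
    return range(start, start + slots_per_req)

# e.g. with slots_per_req=2, group 3 owns raw slots 6 and 7
assert list(group_to_slots(3, 2)) == [6, 7]
```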

Removed (moved to builders):
  ModelRunner._compute_mamba_per_slot_bytes
  ModelRunner.gated_delta_net_state_shape / _dtypes

Sanity check: ModelRunner.__init__ now asserts that any builder
returning compute_per_req_cache_bytes() > 0 has its model_type
registered in InputOutputProcessor._per_req_cache_model_types(),
catching the silent-corruption misconfiguration where a stateful
attention is added but Sequence-construction never gets the
has_per_req_cache=True flag.
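
Roughly, the check could take this shape (a sketch; the real assertion lives in ModelRunner.__init__, and the helper name here is hypothetical):

```python
# Sketch of the startup sanity check; `registered_types` stands in for
# InputOutputProcessor._per_req_cache_model_types().
def check_per_req_cache_wiring(builders, model_type: str, registered_types: set) -> None:
    for builder in builders:
        if builder.compute_per_req_cache_bytes() > 0:
            assert model_type in registered_types, (
                f"{model_type}: stateful attention registered, but Sequence "
                "construction would never set has_per_req_cache=True"
            )
```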

Verified:
  - tests/test_per_req_cache_decoupling.py: 24/24 pass
  - core suite (block_manager, sequence, scheduler, request,
    io_processor_fanout, prefix_cache_accuracy): 118/118 pass
  - Qwen3.5-397B-A17B-FP8 tp=4 simple_inference: 4-prompt completion
    quality unchanged
  - Qwen3.5-397B-A17B-FP8 tp=4 GSM8K (5-shot, 64 concurrent):
      flexible-extract = 0.8757 +/- 0.0091  (baseline 0.8711 from #602)
      strict-match     = 0.8605 +/- 0.0095

* style: black format block_manager.py
valarLip added a commit that referenced this pull request Apr 30, 2026

refactor: delegate ATOM KV cache subsystem to attention builders (commit message identical to the Apr 29 commit above)