
[Draft][ATOM-SGLang][Feat] Enable Deepseek v3 MTP #643

Draft
ZhiweiYan-96 wants to merge 4 commits into ROCm:main from zejunchen-zejun:zhiwei/ds_mtp

Conversation

@ZhiweiYan-96
Contributor

Proposed Design

1. MTP module creation: Override the draft architecture through the external model package

As background knowledge, it is helpful to first detail how SGLang loads the DeepSeek MTP module in its native flow. From SGLang's point of view, DeepSeek MTP is not an auxiliary block hidden inside the target model. It is a standalone draft architecture. The loading path is roughly:

  1. The user enables speculative decoding through server arguments such as --speculative-algorithm NEXTN
  2. SGLang normalizes NEXTN into the EAGLE runtime family in server args
  3. The draft ModelConfig rewrites the DeepSeek V3 draft architecture to DeepseekV3ForCausalLMNextN inside _config_draft_model()
  4. ModelRegistry resolves the model class by that architecture name
  5. The resolved class is instantiated as the draft model and then used by the speculative worker in the propose / verify / extend lifecycle

In other words, SGLang first interprets "DeepSeek MTP" as "a separately loaded draft model", and only then enters the runtime phase. The external model package hook works exactly at this architecture-resolution stage.
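
To make steps 3–5 concrete, here is a minimal sketch of the idea; it is not the actual SGLang source, and apart from the architecture name and ModelRegistry.resolve_model_cls(...) every name below is an illustrative stand-in:

# Illustrative sketch of SGLang's native draft-model loading flow.
DRAFT_ARCH_REWRITES = {
    "DeepseekV3ForCausalLM": "DeepseekV3ForCausalLMNextN",
}

def config_draft_model(draft_model_config):
    # Step 3: rewrite the draft architecture before the registry lookup.
    arch = draft_model_config.architectures[0]
    draft_model_config.architectures[0] = DRAFT_ARCH_REWRITES.get(arch, arch)
    return draft_model_config

def load_draft_model(model_registry, draft_model_config):
    # Steps 4-5: resolve the class by architecture name and instantiate it.
    config_draft_model(draft_model_config)
    model_cls = model_registry.resolve_model_cls(draft_model_config.architectures[0])
    return model_cls(config=draft_model_config)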

On the MTP side, SGLang uses DeepseekV3ForCausalLMNextN as the MTP model architecture:

  • DeepseekV3ForCausalLMNextN

The following diagram shows the native SGLang view of how the MTP module is loaded:

flowchart TD
    subgraph SGL["SGLang domain"]
        A["CLI / server args<br/>--speculative-algorithm NEXTN"]
        B["Normalize algorithm<br/>NEXTN -> EAGLE"]
        C["Build draft ModelConfig"]
        D["_config_draft_model()<br/>rewrite architecture to<br/>DeepseekV3ForCausalLMNextN"]
        E["ModelRegistry.resolve_model_cls(...)"]
        F["Instantiate draft model class"]
        G["Speculative worker uses draft model<br/>propose / verify / extend"]
    end

    A --> B
    B --> C
    C --> D
    D --> E
    E --> F
    F --> G

SGLang allows external model packages to register architectures through SGLANG_EXTERNAL_MODEL_PACKAGE and module-level EntryClass. This is also the core mechanism for ATOM SGLang plugin. The plugin uses this mechanism to expose a class with the exact architecture name expected by SGLang:
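
A minimal sketch of that entry point (the file path follows deepseek_nextn_wrapper.py from the change list; the constructor is abbreviated, and only the class name plus the EntryClass export are what SGLang actually keys on):

# atom/plugin/sglang/models/deepseek_nextn_wrapper.py (sketch)
from atom.models.deepseek_mtp import DeepSeekMTP  # ATOM core draft implementation

class DeepseekV3ForCausalLMNextN:
    # Same architecture name as the upstream draft class, so ModelRegistry
    # resolves to this wrapper instead of sglang/srt/models/deepseek_nextn.py.
    def __init__(self, config, *args, **kwargs):
        self.model = DeepSeekMTP(config)  # delegate draft computation to ATOM core

# SGLang imports the SGLANG_EXTERNAL_MODEL_PACKAGE modules and registers every
# class listed in EntryClass under its architecture name.
EntryClass = [DeepseekV3ForCausalLMNextN]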

Once this class is available in the plugin package, SGLang resolves the draft architecture to the plugin implementation instead of the upstream one in sglang/srt/models/deepseek_nextn.py.

The following diagram illustrates what "overriding the draft architecture" means in practice:

flowchart TD
    subgraph SGL["SGLang domain"]
        A["launch server<br/>SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.models"]
        B["Import external model package"]
        C["Read module EntryClass"]
        D["Register architecture:<br/>DeepseekV3ForCausalLMNextN"]
        E["ModelRegistry.resolve_model_cls(...)"]
        H["upstream draft implementation<br/>sglang/srt/models/deepseek_nextn.py"]
    end

    subgraph PLUGIN["ATOM SGLang Plugin domain"]
        F["plugin wrapper<br/>DeepseekV3ForCausalLMNextN"]
    end

    subgraph CORE["ATOM Core domain"]
        G["DeepSeekMTP"]
    end

    A --> B
    B --> C
    C --> D
    D --> E
    E --> F
    F --> G

    H -. "same architecture name is overridden" .-> D

The important point is that architecture resolution and ModelRegistry selection still happen inside the SGLang domain. The ATOM SGLang Plugin domain only contributes a same-name wrapper through the external package entry point, while the actual draft computation is delegated to DeepSeekMTP in the ATOM Core domain. This makes it easier to separate who owns scheduling and model resolution from who owns the draft implementation.

2. MTP wrapper in ATOM: A thin wrapper as the compatibility bridge

The plugin adds a lightweight wrapper named DeepseekV3ForCausalLMNextN. Externally, it matches the draft-model interface expected by SGLang. Internally, it delegates the actual draft computation to ATOM DeepSeekMTP.

The wrapper is responsible for the following (a condensed interface sketch appears after the list):

  • creating the plugin-mode atom_config
  • rewriting configuration semantics so they match DeepSeekMTP
  • instantiating ATOM/atom/models/deepseek_mtp.py::DeepSeekMTP
  • exposing get_embed_and_head(), set_embed_and_head(), and set_embed() so speculative workers can share embeddings and LM head weights with the target model
  • consuming forward_batch.spec_info.hidden_states in forward()
  • loading weights through the spec_decode=True path
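
Condensed into code, that interface looks roughly like the following; the method bodies are abbreviated sketches, and the base class and attribute names are illustrative rather than the exact plugin code:

class DeepseekV3ForCausalLMNextN(BaseModelWrapper):
    def get_embed_and_head(self):
        # Expose embedding and LM head so the speculative worker can share them.
        return self.model.embed_tokens.weight, self.model.lm_head.weight

    def set_embed_and_head(self, embed, head):
        # Rebind the draft model's embedding and LM head to the target model's weights.
        self.model.embed_tokens.weight = embed
        self.model.lm_head.weight = head

    def set_embed(self, embed):
        self.model.embed_tokens.weight = embed

    def forward(self, input_ids, positions, forward_batch):
        # The draft step consumes the hidden states carried in spec_info.
        hidden_states = forward_batch.spec_info.hidden_states
        return self.model(input_ids, positions, hidden_states, forward_batch)

    def load_weights(self, weights):
        # MTP checkpoints go through ATOM's spec_decode weight-loading path.
        self.model.load_weights(weights, spec_decode=True)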

The design principle is:

  • keep the SGLang architecture name and draft-worker contract unchanged at the top layer
  • reuse ATOM DeepSeekMTP as the implementation at the lower layer

This minimizes duplication, avoids recreating the upstream NextN hierarchy inside the plugin, and makes future improvements to ATOM's native MTP implementation reusable in plugin mode.

Risks

Intrusive changes to process-global runtime/config state handling

ATOM core code currently relies on some process-global runtime/config state. In speculative mode, target and draft wrappers coexist. Without isolation, initializing or running the draft wrapper may overwrite global state used by the target path, leading to subtle cross-contamination in MoE or attention behavior.

To address this, the plugin introduces a runtime scope that explicitly binds and restores the proper runtime context around wrapper __init__, forward(), and load_weights(). This allows target and draft instances to coexist safely.
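
The idea can be sketched as a small context manager; this is a simplified illustration rather than the actual plugin_runtime_scope() implementation, and the single module-level variable stands in for ATOM's real set of process-global state:

from contextlib import contextmanager

_CURRENT_ATOM_CONFIG = None  # stand-in for ATOM's process-global runtime/config state

@contextmanager
def plugin_runtime_scope(atom_config):
    # Bind this wrapper's config on entry and restore the previous binding on exit,
    # so target and draft wrappers can coexist in one process without clobbering
    # each other's global state.
    global _CURRENT_ATOM_CONFIG
    previous = _CURRENT_ATOM_CONFIG
    _CURRENT_ATOM_CONFIG = atom_config
    try:
        yield
    finally:
        _CURRENT_ATOM_CONFIG = previous

# Usage sketch inside the wrapper:
#   def forward(self, input_ids, positions, forward_batch):
#       with plugin_runtime_scope(self.atom_config):
#           return self.model(input_ids, positions, forward_batch)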

However, this also makes an architectural issue visible: the current plugin system still has meaningful complexity around process-global state management. In order to let multiple wrappers coexist, the plugin must repeatedly switch and restore global runtime state at execution boundaries. In that sense, runtime scoping should be understood as a containment mechanism for the current global-state model, not as the ideal long-term abstraction. It solves the correctness problem for this branch, but it also suggests a future direction toward fewer implicit globals and more explicitly instantiated runtime state.

Direct attention backend replacement: sglang_aiter_backend.AiterAttnBackend = ATOMAttnBackendForSgl

The reason can be summarized by the key SGLang call chain:

# eagle_worker.py
self.draft_attn_backend = DraftBackendFactory(...).create_decode_backend()

# draft_utils.py
def _create_aiter_decode_backend(self):
    return AiterMultiStepDraftBackend(...)

# aiter_backend.py
for i in range(self.speculative_num_steps - 1):
    self.attn_backends.append(AiterAttnBackend(model_runner, ...))

In other words, EAGLE draft multi-step decode actually goes through:

flowchart LR
    A["EAGLEWorker"] --> B["DraftBackendFactory"]
    B --> C["AiterMultiStepDraftBackend"]
    C --> D["AiterAttnBackend(...)"]
    R["ATOM-sglang attention registry"] -. "not used on this path" .-> D

So if the plugin only overrides the "aiter" registry entry, but does not also rewrite:

  • sglang.srt.layers.attention.aiter_backend.AiterAttnBackend

then EAGLE draft decode still directly constructs the upstream AiterAttnBackend. That is why this monkeypatch is hacky, but still practically necessary on the current branch.

The plugin is mutating an upstream module symbol directly. This is not a clean extension point.
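
For reference, the replacement amounts to something like the following in atom/plugin/register.py; the import paths are assembled from the module names mentioned above and in the change list, so they may differ slightly from the actual code:

# Sketch of the monkeypatch in atom/plugin/register.py
import sglang.srt.layers.attention.aiter_backend as sglang_aiter_backend

from atom.plugin.sglang.attention_backend.sgl_attn_backend import ATOMAttnBackendForSgl

# Rebind the upstream symbol so that AiterMultiStepDraftBackend, which constructs
# AiterAttnBackend directly, instantiates the ATOM backend instead.
sglang_aiter_backend.AiterAttnBackend = ATOMAttnBackendForSgl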

Other changes

Complete the radix attention forward pass in speculative mode for the following forward modes (a dispatch sketch follows the list):

  • TARGET_VERIFY
  • DRAFT_EXTEND
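
A rough dispatch sketch, assuming SGLang's ForwardMode exposes is_target_verify() / is_draft_extend() checks; the handler names are illustrative and not the plugin's actual methods:

class SpeculativeForwardDispatchSketch:
    def forward(self, q, k, v, layer, forward_batch):
        mode = forward_batch.forward_mode
        if mode.is_target_verify():
            # Target model verifies the tokens proposed by the MTP draft.
            return self._forward_target_verify(q, k, v, layer, forward_batch)
        if mode.is_draft_extend():
            # Draft model extends from the tokens accepted in the last verify step.
            return self._forward_draft_extend(q, k, v, layer, forward_batch)
        # All other modes fall back to the existing prefill/decode paths.
        return self._forward_default(q, k, v, layer, forward_batch)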

Accuracy

(accuracy results screenshot)

Acceptance ratio

(acceptance ratio screenshot)

Copilot AI review requested due to automatic review settings April 24, 2026 09:45
@ZhiweiYan-96 ZhiweiYan-96 marked this pull request as draft April 24, 2026 09:45

Copilot AI left a comment


Pull request overview

Enables DeepSeek v3 MTP (NextN/EAGLE draft model) support in the ATOM SGLang plugin by introducing a draft-architecture wrapper and extending the attention backend to handle speculative modes while isolating ATOM’s process-global runtime state between target and draft models.

Changes:

  • Add a SGLang external-model DeepseekV3ForCausalLMNextN wrapper that delegates draft computation to ATOM DeepSeekMTP.
  • Introduce plugin_runtime_scope() and apply it around init/forward/load to scope global runtime/config state per wrapper instance.
  • Extend the SGLang attention backend to support speculative TARGET_VERIFY / DRAFT_EXTEND (including CUDA graph capture/replay paths) and monkeypatch the upstream AiterAttnBackend symbol for draft paths.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| atom/plugin/sglang/models/deepseek_nextn_wrapper.py | New SGLang draft-model wrapper class (DeepseekV3ForCausalLMNextN) backed by ATOM DeepSeekMTP, including runtime scoping and layer-id retagging. |
| atom/plugin/sglang/models/base_model_wrapper.py | Add plugin_runtime_scope() and use it to scope ATOM globals; add embed/head sharing helpers; scope forward + weight loading. |
| atom/plugin/sglang/attention_backend/sgl_attn_backend.py | Add speculative-mode metadata init for MLA and extend CUDA-graph capture/replay + speculative forward path handling. |
| atom/plugin/sglang/attention_backend/radix_attention.py | Ensure k_scale/v_scale parameters are CUDA-resident for SGLang RadixAttention usage. |
| atom/plugin/register.py | Monkeypatch SGLang's AiterAttnBackend symbol to route direct draft-path construction to ATOMAttnBackendForSgl. |


Comment on lines +159 to +165
del embed_owner.embed_tokens.weight
del self.model.lm_head.weight
embed_owner.embed_tokens.weight = embed
self.model.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()


Copilot AI Apr 24, 2026


set_embed_and_head() calls torch.cuda.empty_cache() / torch.cuda.synchronize() unconditionally. In CPU-only environments (or when CUDA isn’t initialized), this will raise. Consider guarding with torch.cuda.is_available() (as done in deepseek_nextn_wrapper._sync_replaced_weights) or syncing based on embed.device.type == "cuda".

Comment on lines +175 to +178
del embed_owner.embed_tokens.weight
embed_owner.embed_tokens.weight = embed
torch.cuda.empty_cache()
torch.cuda.synchronize()

Copilot AI Apr 24, 2026


set_embed() calls torch.cuda.empty_cache() / torch.cuda.synchronize() unconditionally, which will error on CPU-only runs. Please guard these calls (e.g., torch.cuda.is_available() or embed.is_cuda) to keep the wrapper usable in non-CUDA test/mocking environments.

Comment on lines +1422 to +1427
num_tokens_per_bs = self.speculative_num_steps + 1
seq_lens = seq_lens[:bs]
accept_lens = spec_info.accept_length[:bs]
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
kv_indptr = self.kv_indptr[: bs + 1]

Copilot AI Apr 24, 2026


In init_forward_metadata_replay_cuda_graph(), the draft_extend branch dereferences spec_info.accept_length but spec_info is typed as optional and isn’t validated here. Add an explicit check/raise (similar to the target_verify non-MLA branch) to avoid a hard AttributeError if spec_info is unexpectedly missing.
