[Draft][ATOM-SGLang][Feat] Enable Deepseek v3 MTP #643

ZhiweiYan-96 wants to merge 4 commits into ROCm:main
Conversation
Pull request overview
Enables DeepSeek v3 MTP (NextN/EAGLE draft model) support in the ATOM SGLang plugin by introducing a draft-architecture wrapper and extending the attention backend to handle speculative modes while isolating ATOM’s process-global runtime state between target and draft models.
Changes:
- Add an SGLang external-model `DeepseekV3ForCausalLMNextN` wrapper that delegates draft computation to ATOM `DeepSeekMTP`.
- Introduce `plugin_runtime_scope()` and apply it around init/forward/load to scope global runtime/config state per wrapper instance.
- Extend the SGLang attention backend to support speculative `TARGET_VERIFY`/`DRAFT_EXTEND` (including CUDA graph capture/replay paths) and monkeypatch the upstream `AiterAttnBackend` symbol for draft paths.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| atom/plugin/sglang/models/deepseek_nextn_wrapper.py | New SGLang draft-model wrapper class (DeepseekV3ForCausalLMNextN) backed by ATOM DeepSeekMTP, including runtime scoping and layer-id retagging. |
| atom/plugin/sglang/models/base_model_wrapper.py | Add plugin_runtime_scope() and use it to scope ATOM globals; add embed/head sharing helpers; scope forward + weight loading. |
| atom/plugin/sglang/attention_backend/sgl_attn_backend.py | Add speculative-mode metadata init for MLA and extend CUDA-graph capture/replay + speculative forward path handling. |
| atom/plugin/sglang/attention_backend/radix_attention.py | Ensure k_scale/v_scale parameters are CUDA-resident for SGLang RadixAttention usage. |
| atom/plugin/register.py | Monkeypatch SGLang’s AiterAttnBackend symbol to route direct draft-path construction to ATOMAttnBackendForSgl. |
```python
del embed_owner.embed_tokens.weight
del self.model.lm_head.weight
embed_owner.embed_tokens.weight = embed
self.model.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
```
set_embed_and_head() calls torch.cuda.empty_cache() / torch.cuda.synchronize() unconditionally. In CPU-only environments (or when CUDA isn’t initialized), this will raise. Consider guarding with torch.cuda.is_available() (as done in deepseek_nextn_wrapper._sync_replaced_weights) or syncing based on embed.device.type == "cuda".
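A minimal sketch of the device-keyed variant of that guard, assuming the attribute layout shown in the snippet above (the `embed_owner` resolution is elided here and stands in for whatever the PR actually does):

```python
import torch


def set_embed_and_head(self, embed: torch.Tensor, head: torch.Tensor) -> None:
    embed_owner = self.model  # placeholder: the actual owner resolution is not shown above
    del embed_owner.embed_tokens.weight
    del self.model.lm_head.weight
    embed_owner.embed_tokens.weight = embed
    self.model.lm_head.weight = head
    # Only touch the CUDA runtime when the shared weights actually live on a GPU,
    # so CPU-only or mocked environments do not raise.
    if embed.device.type == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
```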
```python
del embed_owner.embed_tokens.weight
embed_owner.embed_tokens.weight = embed
torch.cuda.empty_cache()
torch.cuda.synchronize()
```
set_embed() calls torch.cuda.empty_cache() / torch.cuda.synchronize() unconditionally, which will error on CPU-only runs. Please guard these calls (e.g., torch.cuda.is_available() or embed.is_cuda) to keep the wrapper usable in non-CUDA test/mocking environments.
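Alternatively, the guard can be factored into a small helper shared by both setters; a sketch follows (the helper name is illustrative, not from the PR):

```python
import torch


def _release_cuda_memory_if_needed(*tensors: torch.Tensor) -> None:
    """Run empty_cache()/synchronize() only when CUDA is actually in play,
    keeping the wrapper usable in CPU-only test/mocking environments."""
    if torch.cuda.is_available() and any(t.is_cuda for t in tensors):
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
```

`set_embed()` and `set_embed_and_head()` would then end with `_release_cuda_memory_if_needed(embed)` / `_release_cuda_memory_if_needed(embed, head)` instead of unconditional `torch.cuda.*` calls.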
```python
num_tokens_per_bs = self.speculative_num_steps + 1
seq_lens = seq_lens[:bs]
accept_lens = spec_info.accept_length[:bs]
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
kv_indptr = self.kv_indptr[: bs + 1]
```
In init_forward_metadata_replay_cuda_graph(), the draft_extend branch dereferences spec_info.accept_length but spec_info is typed as optional and isn’t validated here. Add an explicit check/raise (similar to the target_verify non-MLA branch) to avoid a hard AttributeError if spec_info is unexpectedly missing.
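A minimal sketch of the suggested check, mirroring the explicit validation used in the non-MLA target_verify branch (the message wording is illustrative):

```python
# Inside the draft_extend branch of init_forward_metadata_replay_cuda_graph():
if spec_info is None:
    raise RuntimeError(
        "draft_extend CUDA graph replay requires spec_info.accept_length, "
        "but spec_info is None"
    )
accept_lens = spec_info.accept_length[:bs]
```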
Proposed Design
1. MTP module creation: Override the draft architecture through the external model package
As background knowledge, it is helpful to first detail how `SGLang` loads the DeepSeek MTP module in its native flow. From `SGLang`'s point of view, DeepSeek MTP is not an auxiliary block hidden inside the target model. It is a standalone draft architecture. The loading path is roughly:

- `--speculative-algorithm NEXTN`
- `SGLang` normalizes `NEXTN` into the `EAGLE` runtime family in server args
- `ModelConfig` rewrites the DeepSeek V3 draft architecture to `DeepseekV3ForCausalLMNextN` inside `_config_draft_model()`
- `ModelRegistry` resolves the model class by that architecture name

In other words, `SGLang` first interprets "DeepSeek MTP" as "a separately loaded draft model", and only then enters the runtime phase. The external model package hook works exactly at this architecture-resolution stage.
For the MTP side, `SGLang` uses `DeepseekV3ForCausalLMNextN` as the MTP model architecture.

The following diagram shows the native `SGLang` view of how the MTP module is loaded:

```mermaid
flowchart TD
    subgraph SGL["SGLang domain"]
        A["CLI / server args<br/>--speculative-algorithm NEXTN"]
        B["Normalize algorithm<br/>NEXTN -> EAGLE"]
        C["Build draft ModelConfig"]
        D["_config_draft_model()<br/>rewrite architecture to<br/>DeepseekV3ForCausalLMNextN"]
        E["ModelRegistry.resolve_model_cls(...)"]
        F["Instantiate draft model class"]
        G["Speculative worker uses draft model<br/>propose / verify / extend"]
    end
    A --> B
    B --> C
    C --> D
    D --> E
    E --> F
    F --> G
```

`SGLang` allows external model packages to register architectures through `SGLANG_EXTERNAL_MODEL_PACKAGE` and a module-level `EntryClass`. This is also the core mechanism for the ATOM SGLang plugin. The plugin uses this mechanism to expose a class with the exact architecture name expected by `SGLang`.
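A minimal sketch of what such an external-model module can look like, assuming SGLang's `EntryClass` convention; the class body and the ATOM-side import path are placeholders, not code from this PR:

```python
# atom/plugin/sglang/models/deepseek_nextn_wrapper.py (illustrative layout)
from atom.models.deepseek_mtp import DeepSeekMTP  # assumed ATOM import path


class DeepseekV3ForCausalLMNextN:
    """Same architecture name as the upstream SGLang draft model.

    Externally it satisfies SGLang's draft-model interface; internally it
    delegates the draft computation to ATOM's DeepSeekMTP (details omitted).
    """
    ...


# SGLang scans the external model package for a module-level EntryClass and
# registers the listed architectures in its ModelRegistry.
EntryClass = [DeepseekV3ForCausalLMNextN]
```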
Once this class is available in the plugin package, `SGLang` resolves the draft architecture to the plugin implementation instead of the upstream one in `sglang/srt/models/deepseek_nextn.py`.

The following diagram illustrates what "overriding the draft architecture" means in practice:
```mermaid
flowchart TD
    subgraph SGL["SGLang domain"]
        A["launch server<br/>SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.models"]
        B["Import external model package"]
        C["Read module EntryClass"]
        D["Register architecture:<br/>DeepseekV3ForCausalLMNextN"]
        E["ModelRegistry.resolve_model_cls(...)"]
        H["upstream draft implementation<br/>sglang/srt/models/deepseek_nextn.py"]
    end
    subgraph PLUGIN["ATOM SGLang Plugin domain"]
        F["plugin wrapper<br/>DeepseekV3ForCausalLMNextN"]
    end
    subgraph CORE["ATOM Core domain"]
        G["DeepSeekMTP"]
    end
    A --> B
    B --> C
    C --> D
    D --> E
    E --> F
    F --> G
    H -. "same architecture name is overridden" .-> D
```

The important point is that architecture resolution and `ModelRegistry` selection still happen inside the `SGLang` domain. The ATOM SGLang Plugin domain only contributes a same-name wrapper through the external package entry point, while the actual draft computation is delegated to `DeepSeekMTP` in the ATOM Core domain. This makes it easier to separate who owns scheduling and model resolution from who owns the draft implementation.

2. MTP Wrapper in ATOM: A thin wrapper as the compatibility bridge
The plugin adds a lightweight wrapper named `DeepseekV3ForCausalLMNextN`. Externally, it matches the draft-model interface expected by `SGLang`. Internally, it delegates the actual draft computation to ATOM `DeepSeekMTP`.

The wrapper is responsible for:

- building the `atom_config` and instantiating `DeepSeekMTP` from `ATOM/atom/models/deepseek_mtp.py::DeepSeekMTP`
- exposing `get_embed_and_head()`, `set_embed_and_head()`, and `set_embed()` so speculative workers can share embeddings and LM head weights with the target model
- consuming `forward_batch.spec_info.hidden_states` in `forward()`
- routing the draft computation through the `spec_decode=True` path (a sketch of this flow follows the list)
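A rough sketch of that delegation, under stated assumptions: the method signature follows SGLang's usual model interface, and the `DeepSeekMTP` call shown here (argument order, `spec_decode` keyword) is illustrative rather than the actual ATOM API:

```python
import torch


class DeepseekV3ForCausalLMNextN:
    """Sketch only: the real class lives in atom/plugin/sglang/models/deepseek_nextn_wrapper.py."""

    def __init__(self, atom_mtp_model):
        self.model = atom_mtp_model  # ATOM DeepSeekMTP instance (construction elided)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, forward_batch) -> torch.Tensor:
        # forward_batch follows SGLang's ForwardBatch; in EAGLE/NextN mode its
        # spec_info carries the target model's hidden states for the draft step.
        hidden_states = forward_batch.spec_info.hidden_states
        # Delegate to ATOM's DeepSeekMTP; the call shape here is a placeholder.
        return self.model(input_ids, positions, hidden_states, spec_decode=True)
```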
The design principle is:

- keep the `SGLang` architecture name and draft-worker contract unchanged at the top layer
- reuse ATOM `DeepSeekMTP` as the implementation at the lower layer

This minimizes duplication, avoids recreating the upstream NextN hierarchy inside the plugin, and makes future improvements to ATOM's native MTP implementation reusable in plugin mode.

Risks
Intrusive change to global runtime-state control code
ATOM core code currently relies on some process-global runtime/config state. In speculative mode, target and draft wrappers coexist. Without isolation, initializing or running the draft wrapper may overwrite global state used by the target path, leading to subtle cross-contamination in MoE or attention behavior.

To address this, the plugin introduces a runtime scope that explicitly binds and restores the proper runtime context around wrapper `__init__`, `forward()`, and `load_weights()`. This allows target and draft instances to coexist safely.

However, this also makes an architectural issue visible: the current plugin system still has meaningful complexity around process-global state management. In order to let multiple wrappers coexist, the plugin must repeatedly switch and restore global runtime state at execution boundaries. In that sense, runtime scoping should be understood as a containment mechanism for the current global-state model, not as the ideal long-term abstraction. It solves the correctness problem for this branch, but it also suggests a future direction toward fewer implicit globals and more explicitly instantiated runtime state.
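A minimal sketch of what such a scope can look like as a context manager; the module-level state and the `runtime_state` attribute are hypothetical stand-ins for however ATOM actually stores its process-global runtime/config state:

```python
from contextlib import contextmanager

# Hypothetical stand-in for ATOM's process-global runtime/config state.
_GLOBAL_RUNTIME_STATE = {}


def _swap_global_runtime_state(new_state):
    global _GLOBAL_RUNTIME_STATE
    previous, _GLOBAL_RUNTIME_STATE = _GLOBAL_RUNTIME_STATE, new_state
    return previous


@contextmanager
def plugin_runtime_scope(wrapper):
    """Bind this wrapper's runtime state on entry and restore the previous state on exit."""
    previous = _swap_global_runtime_state(wrapper.runtime_state)  # runtime_state: assumed attribute
    try:
        yield
    finally:
        _swap_global_runtime_state(previous)
```

Wrapped around `__init__`, `forward()`, and `load_weights()`, this keeps the target and draft wrappers from clobbering each other's globals, at the cost of switching state at every execution boundary.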
Direct attn backend replacement

```python
sglang_aiter_backend.AiterAttnBackend = ATOMAttnBackendForSgl
```

The reason can be summarized by the key `SGLang` call chain. In other words, `EAGLE` draft multi-step decode actually goes through:

```mermaid
flowchart LR
    A["EAGLEWorker"] --> B["DraftBackendFactory"]
    B --> C["AiterMultiStepDraftBackend"]
    C --> D["AiterAttnBackend(...)"]
    R["ATOM-sglang attention registry"] -. "not used on this path" .-> D
```

So if the plugin only overrides the `"aiter"` registry entry, but does not also rewrite `sglang.srt.layers.attention.aiter_backend.AiterAttnBackend`, then `EAGLE` draft decode still directly constructs the upstream `AiterAttnBackend`. That is why this monkeypatch is hacky, but still practically necessary on the current branch.

The plugin is mutating an upstream module symbol directly. This is not a clean extension point.
Other changes
Complete the radix attention forward in speculative mode, i.e. for `TARGET_VERIFY` and `DRAFT_EXTEND`.
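A hedged sketch of the kind of mode dispatch this implies in the attention forward path; `ForwardBatch` and the `is_target_verify()` / `is_draft_extend()` predicates are upstream SGLang, while the private method names on the right-hand side are illustrative placeholders:

```python
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


def forward_extend(self, q, k, v, layer, forward_batch: ForwardBatch):
    # Route the speculative modes to dedicated metadata/kernel paths.
    if forward_batch.forward_mode.is_target_verify():
        return self._forward_target_verify(q, k, v, layer, forward_batch)
    if forward_batch.forward_mode.is_draft_extend():
        return self._forward_draft_extend(q, k, v, layer, forward_batch)
    return self._forward_extend_default(q, k, v, layer, forward_batch)
```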
Accuracy

Acceptance ratio