diff --git a/atom/plugin/register.py b/atom/plugin/register.py
index 16853d6ef..d020dd2aa 100644
--- a/atom/plugin/register.py
+++ b/atom/plugin/register.py
@@ -23,21 +23,28 @@ def _register_custom_attention_to_sglang() -> None:
     sglang only accepts pre-registered backend names, so we reuse the "aiter"
     name to inject ATOMAttnBackendForSgl without modifying sglang source.
     """
+    import sglang.srt.layers.attention.aiter_backend as sglang_aiter_backend
+
     from sglang.srt.layers.attention.attention_registry import (
         register_attention_backend,
     )
+    from atom.plugin.sglang.attention_backend.sgl_attn_backend import (
+        ATOMAttnBackendForSgl,
+    )
 
     # here register the custom attention backend with the name "aiter"
     # as sglang defines the fixed attention backend choices, which must be
     # in-tree
     logger.info("Register custom attention backend ATOMAttnBackendForSgl to SGLang")
 
+    # Speculative draft paths instantiate AiterAttnBackend directly inside
+    # AiterMultiStepDraftBackend, bypassing the attention registry. Rebind the
+    # module symbol as well so both registry lookup and direct construction use
+    # the plugin backend.
+    sglang_aiter_backend.AiterAttnBackend = ATOMAttnBackendForSgl
+
     @register_attention_backend("aiter")
     def create_atom_backend(runner):
-        from atom.plugin.sglang.attention_backend.sgl_attn_backend import (
-            ATOMAttnBackendForSgl,
-        )
-
         return ATOMAttnBackendForSgl(runner)
diff --git a/atom/plugin/sglang/attention_backend/radix_attention.py b/atom/plugin/sglang/attention_backend/radix_attention.py
index 329b719ec..dc822b492 100644
--- a/atom/plugin/sglang/attention_backend/radix_attention.py
+++ b/atom/plugin/sglang/attention_backend/radix_attention.py
@@ -88,17 +88,20 @@ def __init__(
                 self.attn.k_scale = atom_parameter(
                     torch.tensor([1.0], dtype=torch.float32, device="cuda")
                 )
+            elif not self.attn.k_scale.is_cuda:
+                self.attn.k_scale = torch.nn.Parameter(
+                    self.attn.k_scale.detach().to(device="cuda"),
+                    requires_grad=False,
+                )
             if self.attn.v_scale is None:
                 self.attn.v_scale = atom_parameter(
                     torch.tensor([1.0], dtype=torch.float32, device="cuda")
                 )
-            # Some SGLang attention backends consume the host-side float scales
-            # directly. Keep them in sync with the device-side defaults so the
-            # plugin path works even when checkpoint loading never populates them.
-            if self.attn.k_scale_float is None:
-                self.attn.k_scale_float = 1.0
-            if self.attn.v_scale_float is None:
-                self.attn.v_scale_float = 1.0
+            elif not self.attn.v_scale.is_cuda:
+                self.attn.v_scale = torch.nn.Parameter(
+                    self.attn.v_scale.detach().to(device="cuda"),
+                    requires_grad=False,
+                )
         else:
             raise NotImplementedError(
                 "RadixAttention is only supported for plugin mode for sglang for now"
diff --git a/atom/plugin/sglang/attention_backend/sgl_attn_backend.py b/atom/plugin/sglang/attention_backend/sgl_attn_backend.py
index ad6723570..4b4f82972 100644
--- a/atom/plugin/sglang/attention_backend/sgl_attn_backend.py
+++ b/atom/plugin/sglang/attention_backend/sgl_attn_backend.py
@@ -20,7 +20,10 @@
 import sglang.srt.layers.attention.aiter_backend as _sglang_aiter
 from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.attention.utils import (
+    create_flashinfer_kv_indices_triton,
+    pad_sequence_with_mask,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import get_bool_env_var
 
@@ -190,6 +193,7 @@ class ForwardMetadata:
     reduce_partial_map: Optional[torch.Tensor] = None
     fp8_prefill_kv_indices: Optional[torch.Tensor] = None
     num_kv_splits: Optional[int] = None
+    run_graph: Optional[bool] = True
     # PA metadata for pa_persistent_fwd (only used in decode mode, non-MLA)
     pa_metadata_qo_indptr: Optional[torch.Tensor] = None
     pa_metadata_pages_kv_indptr: Optional[torch.Tensor] = None
@@ -212,8 +216,9 @@ def __init__(
         model_runner: ModelRunner,
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
+        topk: int = 1,
     ):
-        super().__init__(model_runner, skip_prefill, kv_indptr_buf)
+        super().__init__(model_runner, skip_prefill, kv_indptr_buf, topk)
         mapping = getattr(
             model_runner.token_to_kv_pool, "full_attention_layer_id_mapping", None
         )
@@ -283,6 +288,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
         if forward_batch.forward_mode.is_decode_or_idle():
             self._init_forward_metadata_decode(forward_batch)
+        elif self.use_mla and forward_batch.forward_mode.is_draft_extend():
+            self._init_draft_extend_mla(forward_batch.batch_size, forward_batch)
+        elif self.use_mla and forward_batch.forward_mode.is_target_verify():
+            self._init_target_verify_mla(forward_batch.batch_size, forward_batch)
         else:
             self._init_forward_metadata_extend(forward_batch)
         self._fixup_page_table(forward_batch)
@@ -429,6 +438,188 @@ def _init_forward_metadata_extend(self, forward_batch: ForwardBatch):
         else:
             self._init_extend_mha(bs, forward_batch)
 
+    def _init_draft_extend_mla(self, bs, forward_batch):
+        """Init MLA metadata for speculative draft_extend."""
+        spec_info = forward_batch.spec_info
+        if spec_info is None:
+            raise RuntimeError("MLA draft_extend requires speculative metadata")
+
+        kv_indices, kv_indptr, qo_indptr, _ = spec_info.generate_attn_arg_prefill(
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.seq_lens_sum,
+            self.req_to_token,
+        )
+
+        extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
+        if extend_seq_lens_cpu is not None:
+            max_q_len = (
+                int(extend_seq_lens_cpu.max().item())
+                if isinstance(extend_seq_lens_cpu, torch.Tensor)
+                else max(extend_seq_lens_cpu)
+            )
+        elif forward_batch.extend_seq_lens is not None:
+            max_q_len = int(forward_batch.extend_seq_lens.max().item())
+        elif getattr(spec_info, "accept_length", None) is not None:
+            max_q_len = int(spec_info.accept_length.max().item())
+        else:
+            raise RuntimeError(
+                "MLA draft_extend is missing extend sequence lengths"
+            )
+
+        seq_lens_cpu = forward_batch.seq_lens_cpu
+        max_kv_len = (
+            (
+                int(seq_lens_cpu.max().item())
+                if isinstance(seq_lens_cpu, torch.Tensor)
+                else max(seq_lens_cpu)
+            )
+            if seq_lens_cpu is not None
+            else int(forward_batch.seq_lens.max().item())
+        )
+
+        work_metadata = None
+        work_indptr = None
+        work_info_set = None
+        reduce_indptr = None
+        reduce_final_map = None
+        reduce_partial_map = None
+        num_kv_splits = None
+
+        if _sglang_aiter._use_mla_ps_kernel:
+            (
+                work_metadata,
+                work_indptr,
+                work_info_set,
+                reduce_indptr,
+                reduce_final_map,
+                reduce_partial_map,
+            ) = self.make_mla_decode_meta_data_buffer(max_q_len, bs)
+            num_kv_splits = self.max_split_per_batch
+            self.make_mla_meta_data(
+                qo_indptr,
+                kv_indptr,
+                self.kv_last_page_len[:bs],
+                work_metadata,
+                work_info_set,
+                work_indptr,
+                reduce_indptr,
+                reduce_final_map,
+                reduce_partial_map,
+                max_q_len,
+                fast_mode=_sglang_aiter.fast_mode,
+                max_split_per_batch=num_kv_splits,
+                intra_batch_mode=_sglang_aiter.intra_batch_mode,
+            )
+
+        self.forward_metadata = ForwardMetadata(
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            self.kv_last_page_len[:bs],
+            max_q_len,
+            max_kv_len,
+            None,
+            None,
+            work_metadata=work_metadata,
+            work_info_set=work_info_set,
+            work_indptr=work_indptr,
+            reduce_indptr=reduce_indptr,
+            reduce_final_map=reduce_final_map,
+            reduce_partial_map=reduce_partial_map,
+            num_kv_splits=num_kv_splits,
+            run_graph=False,
+        )
+
+    def _init_target_verify_mla(self, bs, forward_batch):
+        """Init MLA metadata for speculative target_verify."""
+        spec_info = forward_batch.spec_info
+        if spec_info is None:
+            raise RuntimeError("MLA target_verify requires speculative metadata")
+
+        draft_num = spec_info.draft_token_num
+        kv_lens = forward_batch.seq_lens + draft_num
+        kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs
+        device = forward_batch.seq_lens.device
+
+        qo_indptr = torch.arange(
+            0,
+            (1 + bs) * draft_num,
+            step=draft_num,
+            dtype=torch.int32,
+            device=device,
+        )
+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = torch.empty(
+            kv_lens_sum,
+            dtype=torch.int32,
+            device=device,
+        )
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            forward_batch.req_pool_indices,
+            kv_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
+        work_metadata = None
+        work_indptr = None
+        work_info_set = None
+        reduce_indptr = None
+        reduce_final_map = None
+        reduce_partial_map = None
+        num_kv_splits = None
+
+        if _sglang_aiter._use_mla_ps_kernel:
+            (
+                work_metadata,
+                work_indptr,
+                work_info_set,
+                reduce_indptr,
+                reduce_final_map,
+                reduce_partial_map,
+            ) = self.make_mla_decode_meta_data_buffer(draft_num, bs)
+            num_kv_splits = self.max_split_per_batch
+            self.make_mla_meta_data(
+                qo_indptr,
+                kv_indptr,
+                self.kv_last_page_len[:bs],
+                work_metadata,
+                work_info_set,
+                work_indptr,
+                reduce_indptr,
+                reduce_final_map,
+                reduce_partial_map,
+                draft_num,
+                fast_mode=_sglang_aiter.fast_mode,
+                max_split_per_batch=num_kv_splits,
+                intra_batch_mode=_sglang_aiter.intra_batch_mode,
+            )
+
+        self.forward_metadata = ForwardMetadata(
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            self.kv_last_page_len[:bs],
+            draft_num,
+            None,
+            None,
+            None,
+            work_metadata=work_metadata,
+            work_info_set=work_info_set,
+            work_indptr=work_indptr,
+            reduce_indptr=reduce_indptr,
+            reduce_final_map=reduce_final_map,
+            reduce_partial_map=reduce_partial_map,
+            num_kv_splits=num_kv_splits,
+            run_graph=False,
+        )
+
     def _init_extend_mla(self, bs, forward_batch):
         self.mla_indices_updater_prefill.update(
             forward_batch.req_pool_indices,
@@ -711,6 +902,11 @@ def init_cuda_graph_state(
         self.cuda_graph_kv_last_page_len = torch.ones(
             max_bs, dtype=torch.int, device=self.device
         )
+        assert self.cuda_graph_kv_last_page_len.is_cuda, (
+            "ATOMAttnBackendForSgl.init_cuda_graph_state created "
+            f"non-CUDA cuda_graph_kv_last_page_len on {self.cuda_graph_kv_last_page_len.device}, "
+            f"backend={type(self)}"
+        )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_bs * self.max_context_len),
@@ -838,27 +1034,230 @@ def init_forward_metadata_capture_cuda_graph(
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInput],
     ):
-        if not forward_mode.is_decode_or_idle():
-            raise ValueError(f"Invalid mode: {forward_mode=}")
+        num_kv_splits = None
+        work_metadata = None
+        work_info_set = None
+        work_indptr = None
+        reduce_indptr = None
+        reduce_final_map = None
+        reduce_partial_map = None
 
-        if self.use_mla:
-            self._init_mla_cuda_graph_metadata(bs, req_pool_indices, seq_lens)
-        else:
-            page_table = self.page_table[:bs, :]
-            self.seq_lens[:bs].copy_(seq_lens, non_blocking=True)
-            seq_lens_persistent = self.seq_lens[:bs]
-            self.forward_metadata = ForwardMetadata(
-                None,
-                None,
-                None,
+        if forward_mode.is_decode_or_idle():
+            if self.use_mla:
+                self._init_mla_cuda_graph_metadata(bs, req_pool_indices, seq_lens)
+            else:
+                page_table = self.page_table[:bs, :]
+                self.seq_lens[:bs].copy_(seq_lens, non_blocking=True)
+                seq_lens_persistent = self.seq_lens[:bs]
+                self.forward_metadata = ForwardMetadata(
+                    None,
+                    None,
+                    None,
+                    None,
+                    1,
+                    None,
+                    page_table,
+                    seq_lens_persistent,
+                )
+                if self.decode_using_pa_ps:
+                    self._build_pa_metadata_for_decode(
+                        bs, tp_q_head_num=self.num_head
+                    )
+        elif forward_mode.is_target_verify():
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_lens = seq_lens + self.num_draft_tokens if self.use_mla else seq_lens
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                kv_lens,
+                kv_indptr,
                 None,
-                1,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+            max_q_len = self.num_draft_tokens
+
+            if self.use_mla:
+                if _sglang_aiter._use_mla_ps_kernel:
+                    num_kv_splits = self.max_split_per_batch
+                    self.make_mla_meta_data(
+                        qo_indptr,
+                        kv_indptr,
+                        kv_last_page_len,
+                        self.work_metadata,
+                        self.work_info_set,
+                        self.work_indptr,
+                        self.reduce_indptr,
+                        self.reduce_final_map,
+                        self.reduce_partial_map,
+                        max_q_len,
+                        fast_mode=_sglang_aiter.fast_mode,
+                        max_split_per_batch=num_kv_splits,
+                        intra_batch_mode=_sglang_aiter.intra_batch_mode,
+                    )
+                    work_metadata = self.work_metadata
+                    work_info_set = self.work_info_set
+                    work_indptr = self.work_indptr
+                    reduce_indptr = self.reduce_indptr
+                    reduce_final_map = self.reduce_final_map
+                    reduce_partial_map = self.reduce_partial_map
+
+                self.forward_metadata = ForwardMetadata(
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    kv_last_page_len,
+                    max_q_len,
+                    kv_indptr[-1].item(),
+                    None,
+                    None,
+                    work_metadata=work_metadata,
+                    work_info_set=work_info_set,
+                    work_indptr=work_indptr,
+                    reduce_indptr=reduce_indptr,
+                    reduce_final_map=reduce_final_map,
+                    reduce_partial_map=reduce_partial_map,
+                    num_kv_splits=num_kv_splits,
+                )
+                assert (
+                    self.forward_metadata.kv_last_page_len is None
+                    or self.forward_metadata.kv_last_page_len.is_cuda
+                ), (
+                    "capture_cuda_graph TARGET_VERIFY produced non-CUDA kv_last_page_len: "
+                    f"{self.forward_metadata.kv_last_page_len.device}, "
+                    f"backend={type(self)}, metadata_backend={type(self.forward_metadata)}"
+                )
+            else:
+                custom_mask = self.cuda_graph_custom_mask
+                assert spec_info is not None and spec_info.custom_mask is not None
+                custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
+                seq_mask_len = max_q_len * (seq_lens + max_q_len)
+                mask_indptr = self.mask_indptr
+                mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
+                mask_indptr = mask_indptr[: bs + 1]
+                self.forward_metadata = ForwardMetadata(
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    kv_last_page_len,
+                    max_q_len,
+                    kv_indptr[-1].item(),
+                    None,
+                    None,
+                    custom_mask=custom_mask,
+                    mask_indptr=mask_indptr,
+                    max_extend_len=max_q_len,
+                )
+                assert (
+                    self.forward_metadata.kv_last_page_len is None
+                    or self.forward_metadata.kv_last_page_len.is_cuda
+                ), (
+                    "capture_cuda_graph TARGET_VERIFY(non-MLA) produced non-CUDA kv_last_page_len: "
+                    f"{self.forward_metadata.kv_last_page_len.device}, "
+                    f"backend={type(self)}"
+                )
+        elif forward_mode.is_draft_extend():
+            num_tokens_per_bs = self.speculative_num_steps + 1
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                step=num_tokens_per_bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
                 None,
-            page_table,
-            seq_lens_persistent,
+                kv_indices,
+                self.req_to_token.stride(0),
             )
-            if self.decode_using_pa_ps:
-                self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head)
+
+            if self.use_mla:
+                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = num_tokens_per_bs
+                if _sglang_aiter._use_mla_ps_kernel:
+                    num_kv_splits = self.max_split_per_batch
+                    self.make_mla_meta_data(
+                        qo_indptr,
+                        kv_indptr,
+                        kv_last_page_len,
+                        self.work_metadata,
+                        self.work_info_set,
+                        self.work_indptr,
+                        self.reduce_indptr,
+                        self.reduce_final_map,
+                        self.reduce_partial_map,
+                        max_q_len,
+                        fast_mode=_sglang_aiter.fast_mode,
+                        max_split_per_batch=num_kv_splits,
+                        intra_batch_mode=_sglang_aiter.intra_batch_mode,
+                    )
+                    work_metadata = self.work_metadata
+                    work_info_set = self.work_info_set
+                    work_indptr = self.work_indptr
+                    reduce_indptr = self.reduce_indptr
+                    reduce_final_map = self.reduce_final_map
+                    reduce_partial_map = self.reduce_partial_map
+
+                self.forward_metadata = ForwardMetadata(
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    kv_last_page_len,
+                    max_q_len,
+                    kv_indptr[-1].item(),
+                    None,
+                    None,
+                    work_metadata=work_metadata,
+                    work_info_set=work_info_set,
+                    work_indptr=work_indptr,
+                    reduce_indptr=reduce_indptr,
+                    reduce_final_map=reduce_final_map,
+                    reduce_partial_map=reduce_partial_map,
+                    num_kv_splits=num_kv_splits,
+                )
+                assert (
+                    self.forward_metadata.kv_last_page_len is None
+                    or self.forward_metadata.kv_last_page_len.is_cuda
+                ), (
+                    "capture_cuda_graph DRAFT_EXTEND produced non-CUDA kv_last_page_len: "
+                    f"{self.forward_metadata.kv_last_page_len.device}, "
+                    f"backend={type(self)}"
+                )
+            else:
+                self.forward_metadata = ForwardMetadata(
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    None,
+                    num_tokens_per_bs,
+                    None,
+                    None,
+                    None,
+                    custom_mask=None,
+                    mask_indptr=None,
+                    max_extend_len=num_tokens_per_bs,
+                )
+        else:
+            raise ValueError(f"Invalid mode: {forward_mode=}")
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -872,40 +1271,240 @@
         seq_lens_cpu: Optional[torch.Tensor],
         out_cache_loc: Optional[torch.Tensor] = None,
     ):
-        if not forward_mode.is_decode_or_idle():
-            raise ValueError("Invalid forward mode")
+        num_kv_splits = None
+        work_metadata = None
+        work_info_set = None
+        work_indptr = None
+        reduce_indptr = None
+        reduce_final_map = None
+        reduce_partial_map = None
 
-        if self.use_mla:
-            self._init_mla_cuda_graph_metadata(bs, req_pool_indices, seq_lens)
-        else:
-            page_table_persistent = self.page_table
-            seq_lens_persistent = self.seq_lens
-            seq_lens_persistent.fill_(0)
-            page_table_persistent.fill_(0)
-            seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True)
-            max_seq_pages = (
-                seq_lens_cpu.max().item() + self.page_size - 1
-            ) // self.page_size + 1
-            page_table = self.req_to_token[
-                req_pool_indices[:, None],
-                self.strided_indices[:max_seq_pages][None, :],
-            ]
-            page_table_persistent[:bs, :max_seq_pages].copy_(
-                page_table // self.page_size, non_blocking=True
-            )
+        if forward_mode.is_decode_or_idle():
+            if self.use_mla:
+                self._init_mla_cuda_graph_metadata(bs, req_pool_indices, seq_lens)
+            else:
+                page_table_persistent = self.page_table
+                seq_lens_persistent = self.seq_lens
+                seq_lens_persistent.fill_(0)
+                page_table_persistent.fill_(0)
+                seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True)
+                max_seq_pages = (
+                    seq_lens_cpu.max().item() + self.page_size - 1
+                ) // self.page_size + 1
+                page_table = self.req_to_token[
+                    req_pool_indices[:, None],
+                    self.strided_indices[:max_seq_pages][None, :],
+                ]
+                page_table_persistent[:bs, :max_seq_pages].copy_(
+                    page_table // self.page_size, non_blocking=True
+                )
 
-            self.forward_metadata = ForwardMetadata(
-                None,
-                None,
-                None,
+                self.forward_metadata = ForwardMetadata(
+                    None,
+                    None,
+                    None,
+                    None,
+                    1,
+                    None,
+                    page_table_persistent[:bs, :max_seq_pages],
+                    seq_lens_persistent[:bs],
+                )
+                if self.decode_using_pa_ps:
+                    self._build_pa_metadata_for_decode(
+                        bs, tp_q_head_num=self.num_head
+                    )
+        elif forward_mode.is_target_verify():
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_lens = seq_lens + self.num_draft_tokens if self.use_mla else seq_lens
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                kv_lens,
+                kv_indptr,
                 None,
-                1,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+            max_q_len = self.num_draft_tokens
+
+            if self.use_mla:
+                if _sglang_aiter._use_mla_ps_kernel:
+                    num_kv_splits = self.max_split_per_batch
+                    self.make_mla_meta_data(
+                        qo_indptr,
+                        kv_indptr,
+                        kv_last_page_len,
+                        self.work_metadata,
+                        self.work_info_set,
+                        self.work_indptr,
+                        self.reduce_indptr,
+                        self.reduce_final_map,
+                        self.reduce_partial_map,
+                        max_q_len,
+                        fast_mode=_sglang_aiter.fast_mode,
+                        max_split_per_batch=num_kv_splits,
+                        intra_batch_mode=_sglang_aiter.intra_batch_mode,
+                    )
+                    work_metadata = self.work_metadata
+                    work_info_set = self.work_info_set
+                    work_indptr = self.work_indptr
+                    reduce_indptr = self.reduce_indptr
+                    reduce_final_map = self.reduce_final_map
+                    reduce_partial_map = self.reduce_partial_map
+
+                self.forward_metadata = ForwardMetadata(
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    kv_last_page_len,
+                    max_q_len,
+                    kv_indptr[-1].item(),
+                    None,
+                    None,
+                    work_metadata=work_metadata,
+                    work_info_set=work_info_set,
+                    work_indptr=work_indptr,
+                    reduce_indptr=reduce_indptr,
+                    reduce_final_map=reduce_final_map,
+                    reduce_partial_map=reduce_partial_map,
+                    num_kv_splits=num_kv_splits,
+                )
+                assert (
+                    self.forward_metadata.kv_last_page_len is None
+                    or self.forward_metadata.kv_last_page_len.is_cuda
+                ), (
+                    "replay_cuda_graph TARGET_VERIFY produced non-CUDA kv_last_page_len: "
+                    f"{self.forward_metadata.kv_last_page_len.device}, "
+                    f"backend={type(self)}, metadata_backend={type(self.forward_metadata)}"
+                )
+            else:
+                custom_mask = self.cuda_graph_custom_mask
+                assert spec_info is not None and spec_info.custom_mask is not None
+                custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
+                seq_mask_len = max_q_len * (seq_lens + max_q_len)
+                mask_indptr = self.mask_indptr
+                mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
+                mask_indptr = mask_indptr[: bs + 1]
+                self.forward_metadata = ForwardMetadata(
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    kv_last_page_len,
+                    max_q_len,
+                    kv_indptr[-1].item(),
+                    None,
+                    None,
+                    custom_mask=custom_mask,
+                    mask_indptr=mask_indptr,
+                    max_extend_len=max_q_len,
+                )
+                assert (
+                    self.forward_metadata.kv_last_page_len is None
+                    or self.forward_metadata.kv_last_page_len.is_cuda
+                ), (
+                    "replay_cuda_graph TARGET_VERIFY(non-MLA) produced non-CUDA kv_last_page_len: "
+                    f"{self.forward_metadata.kv_last_page_len.device}, "
+                    f"backend={type(self)}"
+                )
+        elif forward_mode.is_draft_extend():
+            num_tokens_per_bs = self.speculative_num_steps + 1
+            seq_lens = seq_lens[:bs]
+            accept_lens = spec_info.accept_length[:bs]
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
                 None,
-            page_table_persistent[:bs, :max_seq_pages],
-            seq_lens_persistent[:bs],
+                kv_indices,
+                self.req_to_token.stride(0),
             )
-            if self.decode_using_pa_ps:
-                self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head)
+
+            if self.use_mla:
+                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = num_tokens_per_bs
+                if _sglang_aiter._use_mla_ps_kernel:
+                    num_kv_splits = self.max_split_per_batch
+                    self.make_mla_meta_data(
+                        qo_indptr,
+                        kv_indptr,
+                        kv_last_page_len,
+                        self.work_metadata,
+                        self.work_info_set,
+                        self.work_indptr,
+                        self.reduce_indptr,
+                        self.reduce_final_map,
+                        self.reduce_partial_map,
+                        max_q_len,
+                        fast_mode=_sglang_aiter.fast_mode,
+                        max_split_per_batch=num_kv_splits,
+                        intra_batch_mode=_sglang_aiter.intra_batch_mode,
+                    )
+                    work_metadata = self.work_metadata
+                    work_info_set = self.work_info_set
+                    work_indptr = self.work_indptr
+                    reduce_indptr = self.reduce_indptr
+                    reduce_final_map = self.reduce_final_map
+                    reduce_partial_map = self.reduce_partial_map
+
+                self.forward_metadata = ForwardMetadata(
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    kv_last_page_len,
+                    max_q_len,
+                    kv_indptr[-1].item(),
+                    None,
+                    None,
+                    work_metadata=work_metadata,
+                    work_info_set=work_info_set,
+                    work_indptr=work_indptr,
+                    reduce_indptr=reduce_indptr,
+                    reduce_final_map=reduce_final_map,
+                    reduce_partial_map=reduce_partial_map,
+                    num_kv_splits=num_kv_splits,
+                )
+                assert (
+                    self.forward_metadata.kv_last_page_len is None
+                    or self.forward_metadata.kv_last_page_len.is_cuda
+                ), (
+                    "replay_cuda_graph DRAFT_EXTEND produced non-CUDA kv_last_page_len: "
+                    f"{self.forward_metadata.kv_last_page_len.device}, "
+                    f"backend={type(self)}"
+                )
+            else:
+                self.forward_metadata = ForwardMetadata(
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    None,
+                    num_tokens_per_bs,
+                    None,
+                    None,
+                    None,
+                    custom_mask=None,
+                    mask_indptr=None,
+                    max_extend_len=num_tokens_per_bs,
+                )
+        else:
+            raise ValueError(f"Invalid mode: {forward_mode=}")
 
     def set_kv_buffer_with_layout_shuffle(
         self,
@@ -1019,6 +1618,7 @@ def _forward_extend_mla(self, q, k, v, layer, forward_batch):
                 layer,
                 K_Buffer,
                 qo_indptr,
+                forward_batch,
             )
         if not forward_batch.forward_mode.is_extend():
             raise ValueError(
@@ -1356,14 +1956,51 @@ def _call_mla_decode_fwd(self, q, k_buffer, o, layer):
             num_kv_splits=md.num_kv_splits,
         )
 
-    def _forward_extend_mla_speculative(self, q, layer, K_Buffer, qo_indptr):
+    def _forward_extend_mla_speculative(
+        self, q, layer, K_Buffer, qo_indptr, forward_batch
+    ):
         """MLA speculative path (target_verify / draft_extend)."""
-        o = q.new_empty(
-            (q.shape[0], layer.tp_q_head_num, layer.v_head_dim),
-            dtype=self.input_dtype,
+        md = self.forward_metadata
+
+        if forward_batch.forward_mode.is_target_verify():
+            o = q.new_empty(
+                (q.shape[0], layer.tp_q_head_num, layer.v_head_dim),
+                dtype=self.input_dtype,
+            )
+            self._call_mla_decode_fwd(q, K_Buffer, o, layer)
+            return o
+
+        if forward_batch.forward_mode.is_draft_extend():
+            if md.run_graph is not True:
+                bs, q_pad, _ = pad_sequence_with_mask(
+                    q.view(q.shape[0], -1),
+                    qo_indptr[:-1],
+                    forward_batch.extend_seq_lens,
+                    md.max_q_len,
+                )
+                o = q.new_empty(
+                    (bs * md.max_q_len, layer.tp_q_head_num, layer.v_head_dim),
+                    dtype=self.input_dtype,
+                )
+                self._call_mla_decode_fwd(
+                    q_pad.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                    K_Buffer,
+                    o,
+                    layer,
+                )
+                total_valid_q = int(qo_indptr[-1].item())
+                return o[:total_valid_q]
+
+            o = q.new_empty(
+                (q.shape[0], layer.tp_q_head_num, layer.v_head_dim),
+                dtype=self.input_dtype,
+            )
+            self._call_mla_decode_fwd(q, K_Buffer, o, layer)
+            return o
+
+        raise ValueError(
+            f"Invalid forward mode for MLA speculative path: {forward_batch.forward_mode=}"
         )
-        self._call_mla_decode_fwd(q, K_Buffer, o, layer)
-        return o
 
     def forward_decode(
         self,
diff --git a/atom/plugin/sglang/models/base_model_wrapper.py b/atom/plugin/sglang/models/base_model_wrapper.py
index 49ceacaf1..9182e0fa6 100644
--- a/atom/plugin/sglang/models/base_model_wrapper.py
+++ b/atom/plugin/sglang/models/base_model_wrapper.py
@@ -7,6 +7,7 @@
 """
 
 import logging
+from contextlib import contextmanager
 from contextvars import ContextVar
 from typing import Any, Iterable, Optional, Tuple, Union
 
@@ -20,6 +21,8 @@
 
 logger = logging.getLogger("atom.plugin.sglang.models")
 
+_RUNTIME_SENTINEL = object()
+
 # Context for patched DeepSeek attention layers that need wrapper state without
 # changing every intermediate forward signature. ContextVar keeps nested or
 # concurrent forwards isolated and lets us reliably restore the prior value.
@@ -32,6 +35,37 @@ def get_current_forward_batch():
     return _current_forward_batch.get()
 
 
+@contextmanager
+def plugin_runtime_scope(
+    *,
+    framework: Optional[str] = None,
+    atom_config: Any = _RUNTIME_SENTINEL,
+):
+    """Temporarily bind plugin runtime globals to one wrapper instance.
+
+    ATOM core currently relies on process-global framework/config state. In
+    SGLang speculative mode both target and draft wrappers coexist, so plugin
+    entrypoints must save/restore those globals around each init/load/forward.
+    """
+
+    import atom.config as atom_config_module
+    import atom.plugin.prepare as plugin_prepare
+
+    prev_framework = plugin_prepare._CURRENT_FRAMEWORK
+    prev_atom_config = getattr(atom_config_module, "_current_atom_config", None)
+
+    if framework is not None:
+        plugin_prepare._set_framework_backbone(framework)
+    if atom_config is not _RUNTIME_SENTINEL:
+        atom_config_module._current_atom_config = atom_config
+
+    try:
+        yield
+    finally:
+        plugin_prepare._CURRENT_FRAMEWORK = prev_framework
+        atom_config_module._current_atom_config = prev_atom_config
+
+
 _MODEL_NAMES = [
     "DeepseekV3ForCausalLM",
     "Qwen3MoeForCausalLM",
@@ -72,7 +106,14 @@ def __init__(
         # Refactor so this wrapper only dispatches the attention backend
         # (register_ops_to_sglang + set_attn_cls), and let sglang handle
         # model construction directly
-        self.model = atom.prepare_model(config=config, engine="sglang")
+        with plugin_runtime_scope(framework="sglang"):
+            from atom.config import get_current_atom_config
+
+            self.model = atom.prepare_model(config=config, engine="sglang")
+            self.atom_config = getattr(self.model, "atom_config", None)
+            if self.atom_config is None:
+                self.atom_config = get_current_atom_config()
+            self.model.atom_config = self.atom_config
         if self.model is None:
             model_arch = getattr(config, "architectures", ["unknown"])[0]
             raise ValueError(
@@ -90,7 +131,51 @@
                 setup_deepseek_for_sglang,
             )
 
-            setup_deepseek_for_sglang(self.model)
+            with plugin_runtime_scope(
+                framework="sglang", atom_config=self.atom_config
+            ):
+                setup_deepseek_for_sglang(self.model)
+
+    def get_embed_and_head(self):
+        if hasattr(self.model, "get_embed_and_head"):
+            return self.model.get_embed_and_head()
+
+        embed_owner = (
+            self.model.model
+            if hasattr(self.model, "model") and hasattr(self.model.model, "embed_tokens")
+            else self.model
+        )
+        return embed_owner.embed_tokens.weight, self.model.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        if hasattr(self.model, "set_embed_and_head"):
+            return self.model.set_embed_and_head(embed, head)
+
+        embed_owner = (
+            self.model.model
+            if hasattr(self.model, "model") and hasattr(self.model.model, "embed_tokens")
+            else self.model
+        )
+        del embed_owner.embed_tokens.weight
+        del self.model.lm_head.weight
+        embed_owner.embed_tokens.weight = embed
+        self.model.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def set_embed(self, embed):
+        if hasattr(self.model, "set_embed"):
+            return self.model.set_embed(embed)
+
+        embed_owner = (
+            self.model.model
+            if hasattr(self.model, "model") and hasattr(self.model.model, "embed_tokens")
+            else self.model
+        )
+        del embed_owner.embed_tokens.weight
+        embed_owner.embed_tokens.weight = embed
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
 
     @torch.no_grad()
     def forward(
@@ -103,35 +188,36 @@
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
         **model_kwargs: Any,
     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
-        model_inputs = dict(
-            input_ids=input_ids,
-            positions=positions,
-            intermediate_tensors=pp_proxy_tensors,
-            inputs_embeds=input_embeds,
-        )
-        if self._uses_forward_batch_context:
-            token = _current_forward_batch.set(forward_batch)
-            try:
-                hidden_states = self.model(**model_inputs)
-            finally:
-                _current_forward_batch.reset(token)
-        else:
-            hidden_states = self.model(
-                **model_inputs,
-                forward_batch=forward_batch,
-                get_embedding=get_embedding,
-                pp_proxy_tensors=pp_proxy_tensors,
-                **model_kwargs,
-            )
-
-        if self.pp_group.is_last_rank:
-            return self.logits_processor(
-                input_ids,
-                hidden_states,
-                self.model.lm_head,
-                forward_batch,
+        with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config):
+            model_inputs = dict(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=pp_proxy_tensors,
+                inputs_embeds=input_embeds,
             )
+            if self._uses_forward_batch_context:
+                token = _current_forward_batch.set(forward_batch)
+                try:
+                    hidden_states = self.model(**model_inputs)
+                finally:
+                    _current_forward_batch.reset(token)
+            else:
+                hidden_states = self.model(
+                    **model_inputs,
+                    forward_batch=forward_batch,
+                    get_embedding=get_embedding,
+                    pp_proxy_tensors=pp_proxy_tensors,
+                    **model_kwargs,
+                )
+
+            if self.pp_group.is_last_rank:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.model.lm_head,
+                    forward_batch,
+                )
-        return hidden_states
+            return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # The passed `weights` iterable from sglang is ignored because ATOM
@@ -140,9 +226,10 @@
         # sglang's default weight iterator.
         from atom.model_loader.loader import load_model_in_plugin_mode
 
-        return load_model_in_plugin_mode(
-            model=self.model, config=self.model.atom_config, prefix="model."
-        )
+        with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config):
+            return load_model_in_plugin_mode(
+                model=self.model, config=self.atom_config, prefix="model."
+            )
 
 
 EntryClass = []
diff --git a/atom/plugin/sglang/models/deepseek_nextn_wrapper.py b/atom/plugin/sglang/models/deepseek_nextn_wrapper.py
new file mode 100644
index 000000000..5993624c1
--- /dev/null
+++ b/atom/plugin/sglang/models/deepseek_nextn_wrapper.py
@@ -0,0 +1,199 @@
+"""ATOM DeepSeek NextN wrapper for SGLang external loading.
+
+This keeps SGLang's draft architecture name (`DeepseekV3ForCausalLMNextN`)
+so ModelRegistry can override the upstream implementation, but delegates the
+actual draft core to ATOM's `DeepSeekMTP`.
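+
+Weight loading is delegated to ATOM's loader against the speculative draft
+model path, and runtime layer ids are retagged to draft-local indices so the
+draft KV cache and attention path line up with SGLang's draft worker.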
+""" + +import logging +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn + +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import get_global_server_args + +from atom.config import SpeculativeConfig +from atom.plugin.config import generate_atom_config_for_plugin_mode +from atom.plugin.sglang.attention_backend.sgl_attention_mla import ( + setup_deepseek_for_sglang, +) +from atom.plugin.sglang.models.base_model_wrapper import ( + _current_forward_batch, + plugin_runtime_scope, +) + +logger = logging.getLogger("atom.plugin.sglang.models") + + +def _sync_replaced_weights() -> None: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def _replace_weight(module: nn.Module, attr_name: str, weight) -> None: + if hasattr(module, attr_name): + delattr(module, attr_name) + setattr(module, attr_name, weight) + + +def _set_runtime_layer_id(layer_module: nn.Module, layer_id: int) -> None: + if hasattr(layer_module, "layer_id"): + layer_module.layer_id = layer_id + if hasattr(layer_module, "layer_num"): + layer_module.layer_num = layer_id + + +def _retag_mtp_runtime_layer_ids(model: nn.Module) -> None: + """Retag MTP runtime layer ids to draft-local indices. + + ATOM's DeepSeekMTP keeps checkpoint/global layer numbering (e.g. 61, 62...) + in module prefixes so weight remapping still works. SGLang's draft KV cache, + however, allocates layers using draft-local indices (0..num_nextn_layers-1). + Rebind only the runtime ids used by the attention/KV-cache path. + """ + + for local_layer_id, mtp_layer in enumerate(model.model.layers.values()): + mtp_block = mtp_layer.mtp_block + self_attn = mtp_block.self_attn + + _set_runtime_layer_id(self_attn, local_layer_id) + + for attr_name in ("mla_attn", "attn_mha"): + attn_obj = getattr(self_attn, attr_name, None) + if attn_obj is None: + continue + _set_runtime_layer_id(attn_obj, local_layer_id) + nested_attn = getattr(attn_obj, "attn", None) + if nested_attn is not None: + _set_runtime_layer_id(nested_attn, local_layer_id) + + +class DeepseekV3ForCausalLMNextN(nn.Module): + """SGLang-compatible draft wrapper backed by ATOM's `DeepSeekMTP`.""" + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + del prefix + super().__init__() + + logger.info("Initializing ATOM backend for %s", self.__class__.__name__) + + self.pp_group = get_pp_group() + self.quant_config = quant_config + self.config = config + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + + with plugin_runtime_scope(framework="sglang"): + self.atom_config = generate_atom_config_for_plugin_mode(config) + + # Draft workers need ATOM's MTP-specific config semantics rather than the + # default target-model translation used by the generic plugin wrapper. 
+        SpeculativeConfig.hf_config_override(self.atom_config.hf_config)
+
+        with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config):
+            from atom.plugin.register import (
+                init_aiter_dist,
+                register_ops_to_sglang,
+                set_attn_cls,
+            )
+            from atom.models.deepseek_mtp import DeepSeekMTP
+
+            register_ops_to_sglang(atom_config=self.atom_config)
+            set_attn_cls()
+            init_aiter_dist(config=self.atom_config)
+
+            self.model = DeepSeekMTP(atom_config=self.atom_config)
+            self.model.atom_config = self.atom_config
+            setup_deepseek_for_sglang(self.model)
+            _retag_mtp_runtime_layer_ids(self.model)
+
+        self.logits_processor = LogitsProcessor(config)
+        self.lm_head = self._first_mtp_layer().shared_head.head
+
+    def _mtp_layers(self):
+        return list(self.model.model.layers.values())
+
+    def _first_mtp_layer(self):
+        layers = self._mtp_layers()
+        if not layers:
+            raise ValueError("DeepSeekMTP does not contain any draft layers")
+        return layers[0]
+
+    def get_embed_and_head(self):
+        return self.model.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        self.set_embed(embed)
+        for mtp_layer in self._mtp_layers():
+            _replace_weight(mtp_layer.shared_head.head, "weight", head)
+        self.lm_head = self._first_mtp_layer().shared_head.head
+        _sync_replaced_weights()
+
+    def set_embed(self, embed):
+        _replace_weight(self.model.model.embed_tokens, "weight", embed)
+        _sync_replaced_weights()
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        **kwargs,
+    ):
+        del kwargs
+        if forward_batch.spec_info is None:
+            raise ValueError("DeepSeek MTP draft forward requires speculative info")
+
+        with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config):
+            token = _current_forward_batch.set(forward_batch)
+            try:
+                hidden_states = self.model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    hidden_states=forward_batch.spec_info.hidden_states,
+                    inputs_embeds=input_embeds,
+                )
+            finally:
+                _current_forward_batch.reset(token)
+
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids,
+                hidden_states,
+                self.lm_head,
+                forward_batch,
+            )
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        del weights
+        from atom.model_loader.loader import load_model
+
+        server_args = get_global_server_args()
+        draft_model_path = (
+            server_args.speculative_draft_model_path or server_args.model_path
+        )
+        self.atom_config.model = draft_model_path
+        with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config):
+            return load_model(
+                model=self.model,
+                model_name_or_path=draft_model_path,
+                hf_config=self.atom_config.hf_config,
+                load_dummy=self.atom_config.load_dummy,
+                spec_decode=True,
+            )
+
+
+EntryClass = [DeepseekV3ForCausalLMNextN]