From aa72c643fdc2d2f4f6eacd369a5ccc82fa782065 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Tue, 14 Apr 2026 02:19:18 +0000 Subject: [PATCH 1/4] MTP(num_step=1) for DeeepSeek --- atom/plugin/register.py | 15 +- .../attention_backend/radix_attention.py | 17 +- .../attention_backend/sgl_attn_backend.py | 742 ++++++++++++++++-- .../sglang/models/base_model_wrapper.py | 153 +++- .../sglang/models/deepseek_nextn_wrapper.py | 199 +++++ 5 files changed, 1029 insertions(+), 97 deletions(-) create mode 100644 atom/plugin/sglang/models/deepseek_nextn_wrapper.py diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 16853d6ef..d020dd2aa 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -23,21 +23,28 @@ def _register_custom_attention_to_sglang() -> None: sglang only accepts pre-registered backend names, so we reuse the "aiter" name to inject ATOMAttnBackendForSgl without modifying sglang source. """ + import sglang.srt.layers.attention.aiter_backend as sglang_aiter_backend + from sglang.srt.layers.attention.attention_registry import ( register_attention_backend, ) + from atom.plugin.sglang.attention_backend.sgl_attn_backend import ( + ATOMAttnBackendForSgl, + ) # here register the custom attention backend with the name "aiter" # as sglang defines the fixed attention backend choices, which must be # in-tree logger.info("Register custom attention backend ATOMAttnBackendForSgl to SGLang") + # Speculative draft paths instantiate AiterAttnBackend directly inside + # AiterMultiStepDraftBackend, bypassing the attention registry. Rebind the + # module symbol as well so both registry lookup and direct construction use + # the plugin backend. + sglang_aiter_backend.AiterAttnBackend = ATOMAttnBackendForSgl + @register_attention_backend("aiter") def create_atom_backend(runner): - from atom.plugin.sglang.attention_backend.sgl_attn_backend import ( - ATOMAttnBackendForSgl, - ) - return ATOMAttnBackendForSgl(runner) diff --git a/atom/plugin/sglang/attention_backend/radix_attention.py b/atom/plugin/sglang/attention_backend/radix_attention.py index 329b719ec..dc822b492 100644 --- a/atom/plugin/sglang/attention_backend/radix_attention.py +++ b/atom/plugin/sglang/attention_backend/radix_attention.py @@ -88,17 +88,20 @@ def __init__( self.attn.k_scale = atom_parameter( torch.tensor([1.0], dtype=torch.float32, device="cuda") ) + elif not self.attn.k_scale.is_cuda: + self.attn.k_scale = torch.nn.Parameter( + self.attn.k_scale.detach().to(device="cuda"), + requires_grad=False, + ) if self.attn.v_scale is None: self.attn.v_scale = atom_parameter( torch.tensor([1.0], dtype=torch.float32, device="cuda") ) - # Some SGLang attention backends consume the host-side float scales - # directly. Keep them in sync with the device-side defaults so the - # plugin path works even when checkpoint loading never populates them. - if self.attn.k_scale_float is None: - self.attn.k_scale_float = 1.0 - if self.attn.v_scale_float is None: - self.attn.v_scale_float = 1.0 + elif not self.attn.v_scale.is_cuda: + self.attn.v_scale = torch.nn.Parameter( + self.attn.v_scale.detach().to(device="cuda"), + requires_grad=False, + ) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" diff --git a/atom/plugin/sglang/attention_backend/sgl_attn_backend.py b/atom/plugin/sglang/attention_backend/sgl_attn_backend.py index ad6723570..2a26f8083 100644 --- a/atom/plugin/sglang/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/sglang/attention_backend/sgl_attn_backend.py @@ -20,7 +20,10 @@ import sglang.srt.layers.attention.aiter_backend as _sglang_aiter from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend -from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.layers.attention.utils import ( + create_flashinfer_kv_indices_triton, + pad_sequence_with_mask, +) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import get_bool_env_var @@ -190,6 +193,7 @@ class ForwardMetadata: reduce_partial_map: Optional[torch.Tensor] = None fp8_prefill_kv_indices: Optional[torch.Tensor] = None num_kv_splits: Optional[int] = None + run_graph: Optional[bool] = True # PA metadata for pa_persistent_fwd (only used in decode mode, non-MLA) pa_metadata_qo_indptr: Optional[torch.Tensor] = None pa_metadata_pages_kv_indptr: Optional[torch.Tensor] = None @@ -283,6 +287,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" if forward_batch.forward_mode.is_decode_or_idle(): self._init_forward_metadata_decode(forward_batch) + elif self.use_mla and forward_batch.forward_mode.is_draft_extend(): + self._init_draft_extend_mla(forward_batch.batch_size, forward_batch) + elif self.use_mla and forward_batch.forward_mode.is_target_verify(): + self._init_target_verify_mla(forward_batch.batch_size, forward_batch) else: self._init_forward_metadata_extend(forward_batch) self._fixup_page_table(forward_batch) @@ -429,6 +437,188 @@ def _init_forward_metadata_extend(self, forward_batch: ForwardBatch): else: self._init_extend_mha(bs, forward_batch) + def _init_draft_extend_mla(self, bs, forward_batch): + """Init MLA metadata for speculative draft_extend.""" + spec_info = forward_batch.spec_info + if spec_info is None: + raise RuntimeError("MLA draft_extend requires speculative metadata") + + kv_indices, kv_indptr, qo_indptr, _ = spec_info.generate_attn_arg_prefill( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + self.req_to_token, + ) + + extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu + if extend_seq_lens_cpu is not None: + max_q_len = ( + int(extend_seq_lens_cpu.max().item()) + if isinstance(extend_seq_lens_cpu, torch.Tensor) + else max(extend_seq_lens_cpu) + ) + elif forward_batch.extend_seq_lens is not None: + max_q_len = int(forward_batch.extend_seq_lens.max().item()) + elif getattr(spec_info, "accept_length", None) is not None: + max_q_len = int(spec_info.accept_length.max().item()) + else: + raise RuntimeError( + "MLA draft_extend is missing extend sequence lengths" + ) + + seq_lens_cpu = forward_batch.seq_lens_cpu + max_kv_len = ( + ( + int(seq_lens_cpu.max().item()) + if isinstance(seq_lens_cpu, torch.Tensor) + else max(seq_lens_cpu) + ) + if seq_lens_cpu is not None + else int(forward_batch.seq_lens.max().item()) + ) + + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + num_kv_splits = None + + if _sglang_aiter._use_mla_ps_kernel: + ( + work_metadata, + work_indptr, + work_info_set, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + ) = self.make_mla_decode_meta_data_buffer(max_q_len, bs) + num_kv_splits = self.max_split_per_batch + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + self.kv_last_page_len[:bs], + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + self.kv_last_page_len[:bs], + max_q_len, + max_kv_len, + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + run_graph=False, + ) + + def _init_target_verify_mla(self, bs, forward_batch): + """Init MLA metadata for speculative target_verify.""" + spec_info = forward_batch.spec_info + if spec_info is None: + raise RuntimeError("MLA target_verify requires speculative metadata") + + draft_num = spec_info.draft_token_num + kv_lens = forward_batch.seq_lens + draft_num + kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs + device = forward_batch.seq_lens.device + + qo_indptr = torch.arange( + 0, + (1 + bs) * draft_num, + step=draft_num, + dtype=torch.int32, + device=device, + ) + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + kv_lens_sum, + dtype=torch.int32, + device=device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + kv_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + num_kv_splits = None + + if _sglang_aiter._use_mla_ps_kernel: + ( + work_metadata, + work_indptr, + work_info_set, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + ) = self.make_mla_decode_meta_data_buffer(draft_num, bs) + num_kv_splits = self.max_split_per_batch + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + self.kv_last_page_len[:bs], + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + draft_num, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + self.kv_last_page_len[:bs], + draft_num, + None, + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + run_graph=False, + ) + def _init_extend_mla(self, bs, forward_batch): self.mla_indices_updater_prefill.update( forward_batch.req_pool_indices, @@ -711,6 +901,11 @@ def init_cuda_graph_state( self.cuda_graph_kv_last_page_len = torch.ones( max_bs, dtype=torch.int, device=self.device ) + assert self.cuda_graph_kv_last_page_len.is_cuda, ( + "ATOMAttnBackendForSgl.init_cuda_graph_state created " + f"non-CUDA cuda_graph_kv_last_page_len on {self.cuda_graph_kv_last_page_len.device}, " + f"backend={type(self)}" + ) if kv_indices_buf is None: self.cuda_graph_kv_indices = torch.zeros( (max_bs * self.max_context_len), @@ -838,27 +1033,230 @@ def init_forward_metadata_capture_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], ): - if not forward_mode.is_decode_or_idle(): - raise ValueError(f"Invalid mode: {forward_mode=}") + num_kv_splits = None + work_metadata = None + work_info_set = None + work_indptr = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None - if self.use_mla: - self._init_mla_cuda_graph_metadata(bs, req_pool_indices, seq_lens) - else: - page_table = self.page_table[:bs, :] - self.seq_lens[:bs].copy_(seq_lens, non_blocking=True) - seq_lens_persistent = self.seq_lens[:bs] - self.forward_metadata = ForwardMetadata( - None, - None, - None, + if forward_mode.is_decode_or_idle(): + if self.use_mla: + self._init_mla_cuda_graph_metadata(bs, req_pool_indices, seq_lens) + else: + page_table = self.page_table[:bs, :] + self.seq_lens[:bs].copy_(seq_lens, non_blocking=True) + seq_lens_persistent = self.seq_lens[:bs] + self.forward_metadata = ForwardMetadata( + None, + None, + None, + None, + 1, + None, + page_table, + seq_lens_persistent, + ) + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode( + bs, tp_q_head_num=self.num_head + ) + elif forward_mode.is_target_verify(): + qo_indptr = self.qo_indptr[: bs + 1] + qo_indptr[: bs + 1] = torch.arange( + 0, + (1 + bs) * self.num_draft_tokens, + step=self.num_draft_tokens, + dtype=torch.int32, + device=self.device, + ) + kv_lens = seq_lens + self.num_draft_tokens if self.use_mla else seq_lens + kv_indptr = self.kv_indptr[: bs + 1] + kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0) + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + kv_lens, + kv_indptr, None, - 1, + kv_indices, + self.req_to_token.stride(0), + ) + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = self.num_draft_tokens + + if self.use_mla: + if _sglang_aiter._use_mla_ps_kernel: + num_kv_splits = self.max_split_per_batch + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + kv_indptr[-1].item(), + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + assert ( + self.forward_metadata.kv_last_page_len is None + or self.forward_metadata.kv_last_page_len.is_cuda + ), ( + "capture_cuda_graph TARGET_VERIFY produced non-CUDA kv_last_page_len: " + f"{self.forward_metadata.kv_last_page_len.device}, " + f"backend={type(self)}, metadata_backend={type(self.forward_metadata)}" + ) + else: + custom_mask = self.cuda_graph_custom_mask + assert spec_info is not None and spec_info.custom_mask is not None + custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask + seq_mask_len = max_q_len * (seq_lens + max_q_len) + mask_indptr = self.mask_indptr + mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0) + mask_indptr = mask_indptr[: bs + 1] + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + kv_indptr[-1].item(), + None, + None, + custom_mask=custom_mask, + mask_indptr=mask_indptr, + max_extend_len=max_q_len, + ) + assert ( + self.forward_metadata.kv_last_page_len is None + or self.forward_metadata.kv_last_page_len.is_cuda + ), ( + "capture_cuda_graph TARGET_VERIFY(non-MLA) produced non-CUDA kv_last_page_len: " + f"{self.forward_metadata.kv_last_page_len.device}, " + f"backend={type(self)}" + ) + elif forward_mode.is_draft_extend(): + num_tokens_per_bs = self.speculative_num_steps + 1 + qo_indptr = self.qo_indptr[: bs + 1] + qo_indptr[: bs + 1] = torch.arange( + 0, + bs * num_tokens_per_bs + 1, + step=num_tokens_per_bs, + dtype=torch.int32, + device=self.device, + ) + kv_indptr = self.kv_indptr[: bs + 1] + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, None, - page_table, - seq_lens_persistent, + kv_indices, + self.req_to_token.stride(0), ) - if self.decode_using_pa_ps: - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + + if self.use_mla: + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = num_tokens_per_bs + if _sglang_aiter._use_mla_ps_kernel: + num_kv_splits = self.max_split_per_batch + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + kv_indptr[-1].item(), + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + assert ( + self.forward_metadata.kv_last_page_len is None + or self.forward_metadata.kv_last_page_len.is_cuda + ), ( + "capture_cuda_graph DRAFT_EXTEND produced non-CUDA kv_last_page_len: " + f"{self.forward_metadata.kv_last_page_len.device}, " + f"backend={type(self)}" + ) + else: + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + None, + num_tokens_per_bs, + None, + None, + None, + custom_mask=None, + mask_indptr=None, + max_extend_len=num_tokens_per_bs, + ) + else: + raise ValueError(f"Invalid mode: {forward_mode=}") def init_forward_metadata_replay_cuda_graph( self, @@ -872,40 +1270,240 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_cpu: Optional[torch.Tensor], out_cache_loc: Optional[torch.Tensor] = None, ): - if not forward_mode.is_decode_or_idle(): - raise ValueError("Invalid forward mode") + num_kv_splits = None + work_metadata = None + work_info_set = None + work_indptr = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None - if self.use_mla: - self._init_mla_cuda_graph_metadata(bs, req_pool_indices, seq_lens) - else: - page_table_persistent = self.page_table - seq_lens_persistent = self.seq_lens - seq_lens_persistent.fill_(0) - page_table_persistent.fill_(0) - seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True) - max_seq_pages = ( - seq_lens_cpu.max().item() + self.page_size - 1 - ) // self.page_size + 1 - page_table = self.req_to_token[ - req_pool_indices[:, None], - self.strided_indices[:max_seq_pages][None, :], - ] - page_table_persistent[:bs, :max_seq_pages].copy_( - page_table // self.page_size, non_blocking=True - ) + if forward_mode.is_decode_or_idle(): + if self.use_mla: + self._init_mla_cuda_graph_metadata(bs, req_pool_indices, seq_lens) + else: + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True) + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 + page_table = self.req_to_token[ + req_pool_indices[:, None], + self.strided_indices[:max_seq_pages][None, :], + ] + page_table_persistent[:bs, :max_seq_pages].copy_( + page_table // self.page_size, non_blocking=True + ) - self.forward_metadata = ForwardMetadata( - None, - None, - None, + self.forward_metadata = ForwardMetadata( + None, + None, + None, + None, + 1, + None, + page_table_persistent[:bs, :max_seq_pages], + seq_lens_persistent[:bs], + ) + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode( + bs, tp_q_head_num=self.num_head + ) + elif forward_mode.is_target_verify(): + bs = len(req_pool_indices) + qo_indptr = self.qo_indptr[: bs + 1] + qo_indptr[: bs + 1] = torch.arange( + 0, + (1 + bs) * self.num_draft_tokens, + step=self.num_draft_tokens, + dtype=torch.int32, + device=self.device, + ) + kv_lens = seq_lens + self.num_draft_tokens if self.use_mla else seq_lens + kv_indptr = self.kv_indptr[: bs + 1] + kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0) + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + kv_lens, + kv_indptr, None, - 1, + kv_indices, + self.req_to_token.stride(0), + ) + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = self.num_draft_tokens + + if self.use_mla: + if _sglang_aiter._use_mla_ps_kernel: + num_kv_splits = self.max_split_per_batch + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + kv_indptr[-1].item(), + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + assert ( + self.forward_metadata.kv_last_page_len is None + or self.forward_metadata.kv_last_page_len.is_cuda + ), ( + "replay_cuda_graph TARGET_VERIFY produced non-CUDA kv_last_page_len: " + f"{self.forward_metadata.kv_last_page_len.device}, " + f"backend={type(self)}, metadata_backend={type(self.forward_metadata)}" + ) + else: + custom_mask = self.cuda_graph_custom_mask + assert spec_info is not None and spec_info.custom_mask is not None + custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask + seq_mask_len = max_q_len * (seq_lens + max_q_len) + mask_indptr = self.mask_indptr + mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0) + mask_indptr = mask_indptr[: bs + 1] + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + kv_indptr[-1].item(), + None, + None, + custom_mask=custom_mask, + mask_indptr=mask_indptr, + max_extend_len=max_q_len, + ) + assert ( + self.forward_metadata.kv_last_page_len is None + or self.forward_metadata.kv_last_page_len.is_cuda + ), ( + "replay_cuda_graph TARGET_VERIFY(non-MLA) produced non-CUDA kv_last_page_len: " + f"{self.forward_metadata.kv_last_page_len.device}, " + f"backend={type(self)}" + ) + elif forward_mode.is_draft_extend(): + num_tokens_per_bs = self.speculative_num_steps + 1 + seq_lens = seq_lens[:bs] + accept_lens = spec_info.accept_length[:bs] + qo_indptr = self.qo_indptr[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0) + kv_indptr = self.kv_indptr[: bs + 1] + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, None, - page_table_persistent[:bs, :max_seq_pages], - seq_lens_persistent[:bs], + kv_indices, + self.req_to_token.stride(0), ) - if self.decode_using_pa_ps: - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + + if self.use_mla: + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = num_tokens_per_bs + if _sglang_aiter._use_mla_ps_kernel: + num_kv_splits = self.max_split_per_batch + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + kv_indptr[-1].item(), + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + assert ( + self.forward_metadata.kv_last_page_len is None + or self.forward_metadata.kv_last_page_len.is_cuda + ), ( + "replay_cuda_graph DRAFT_EXTEND produced non-CUDA kv_last_page_len: " + f"{self.forward_metadata.kv_last_page_len.device}, " + f"backend={type(self)}" + ) + else: + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + None, + num_tokens_per_bs, + None, + None, + None, + custom_mask=None, + mask_indptr=None, + max_extend_len=num_tokens_per_bs, + ) + else: + raise ValueError(f"Invalid mode: {forward_mode=}") def set_kv_buffer_with_layout_shuffle( self, @@ -1019,6 +1617,7 @@ def _forward_extend_mla(self, q, k, v, layer, forward_batch): layer, K_Buffer, qo_indptr, + forward_batch, ) if not forward_batch.forward_mode.is_extend(): raise ValueError( @@ -1356,14 +1955,51 @@ def _call_mla_decode_fwd(self, q, k_buffer, o, layer): num_kv_splits=md.num_kv_splits, ) - def _forward_extend_mla_speculative(self, q, layer, K_Buffer, qo_indptr): + def _forward_extend_mla_speculative( + self, q, layer, K_Buffer, qo_indptr, forward_batch + ): """MLA speculative path (target_verify / draft_extend).""" - o = q.new_empty( - (q.shape[0], layer.tp_q_head_num, layer.v_head_dim), - dtype=self.input_dtype, + md = self.forward_metadata + + if forward_batch.forward_mode.is_target_verify(): + o = q.new_empty( + (q.shape[0], layer.tp_q_head_num, layer.v_head_dim), + dtype=self.input_dtype, + ) + self._call_mla_decode_fwd(q, K_Buffer, o, layer) + return o + + if forward_batch.forward_mode.is_draft_extend(): + if md.run_graph is not True: + bs, q_pad, _ = pad_sequence_with_mask( + q.view(q.shape[0], -1), + qo_indptr[:-1], + forward_batch.extend_seq_lens, + md.max_q_len, + ) + o = q.new_empty( + (bs * md.max_q_len, layer.tp_q_head_num, layer.v_head_dim), + dtype=self.input_dtype, + ) + self._call_mla_decode_fwd( + q_pad.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + K_Buffer, + o, + layer, + ) + total_valid_q = int(qo_indptr[-1].item()) + return o[:total_valid_q] + + o = q.new_empty( + (q.shape[0], layer.tp_q_head_num, layer.v_head_dim), + dtype=self.input_dtype, + ) + self._call_mla_decode_fwd(q, K_Buffer, o, layer) + return o + + raise ValueError( + f"Invalid forward mode for MLA speculative path: {forward_batch.forward_mode=}" ) - self._call_mla_decode_fwd(q, K_Buffer, o, layer) - return o def forward_decode( self, diff --git a/atom/plugin/sglang/models/base_model_wrapper.py b/atom/plugin/sglang/models/base_model_wrapper.py index 49ceacaf1..9182e0fa6 100644 --- a/atom/plugin/sglang/models/base_model_wrapper.py +++ b/atom/plugin/sglang/models/base_model_wrapper.py @@ -7,6 +7,7 @@ """ import logging +from contextlib import contextmanager from contextvars import ContextVar from typing import Any, Iterable, Optional, Tuple, Union @@ -20,6 +21,8 @@ logger = logging.getLogger("atom.plugin.sglang.models") +_RUNTIME_SENTINEL = object() + # Context for patched DeepSeek attention layers that need wrapper state without # changing every intermediate forward signature. ContextVar keeps nested or # concurrent forwards isolated and lets us reliably restore the prior value. @@ -32,6 +35,37 @@ def get_current_forward_batch(): return _current_forward_batch.get() +@contextmanager +def plugin_runtime_scope( + *, + framework: Optional[str] = None, + atom_config: Any = _RUNTIME_SENTINEL, +): + """Temporarily bind plugin runtime globals to one wrapper instance. + + ATOM core currently relies on process-global framework/config state. In + SGLang speculative mode both target and draft wrappers coexist, so plugin + entrypoints must save/restore those globals around each init/load/forward. + """ + + import atom.config as atom_config_module + import atom.plugin.prepare as plugin_prepare + + prev_framework = plugin_prepare._CURRENT_FRAMEWORK + prev_atom_config = getattr(atom_config_module, "_current_atom_config", None) + + if framework is not None: + plugin_prepare._set_framework_backbone(framework) + if atom_config is not _RUNTIME_SENTINEL: + atom_config_module._current_atom_config = atom_config + + try: + yield + finally: + plugin_prepare._CURRENT_FRAMEWORK = prev_framework + atom_config_module._current_atom_config = prev_atom_config + + _MODEL_NAMES = [ "DeepseekV3ForCausalLM", "Qwen3MoeForCausalLM", @@ -72,7 +106,14 @@ def __init__( # Refactor so this wrapper only dispatches the attention backend # (register_ops_to_sglang + set_attn_cls), and let sglang handle # model construction directly - self.model = atom.prepare_model(config=config, engine="sglang") + with plugin_runtime_scope(framework="sglang"): + from atom.config import get_current_atom_config + + self.model = atom.prepare_model(config=config, engine="sglang") + self.atom_config = getattr(self.model, "atom_config", None) + if self.atom_config is None: + self.atom_config = get_current_atom_config() + self.model.atom_config = self.atom_config if self.model is None: model_arch = getattr(config, "architectures", ["unknown"])[0] raise ValueError( @@ -90,7 +131,51 @@ def __init__( setup_deepseek_for_sglang, ) - setup_deepseek_for_sglang(self.model) + with plugin_runtime_scope( + framework="sglang", atom_config=self.atom_config + ): + setup_deepseek_for_sglang(self.model) + + def get_embed_and_head(self): + if hasattr(self.model, "get_embed_and_head"): + return self.model.get_embed_and_head() + + embed_owner = ( + self.model.model + if hasattr(self.model, "model") and hasattr(self.model.model, "embed_tokens") + else self.model + ) + return embed_owner.embed_tokens.weight, self.model.lm_head.weight + + def set_embed_and_head(self, embed, head): + if hasattr(self.model, "set_embed_and_head"): + return self.model.set_embed_and_head(embed, head) + + embed_owner = ( + self.model.model + if hasattr(self.model, "model") and hasattr(self.model.model, "embed_tokens") + else self.model + ) + del embed_owner.embed_tokens.weight + del self.model.lm_head.weight + embed_owner.embed_tokens.weight = embed + self.model.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def set_embed(self, embed): + if hasattr(self.model, "set_embed"): + return self.model.set_embed(embed) + + embed_owner = ( + self.model.model + if hasattr(self.model, "model") and hasattr(self.model.model, "embed_tokens") + else self.model + ) + del embed_owner.embed_tokens.weight + embed_owner.embed_tokens.weight = embed + torch.cuda.empty_cache() + torch.cuda.synchronize() @torch.no_grad() def forward( @@ -103,35 +188,36 @@ def forward( pp_proxy_tensors: Optional[PPProxyTensors] = None, **model_kwargs: Any, ) -> Union[LogitsProcessorOutput, PPProxyTensors]: - model_inputs = dict( - input_ids=input_ids, - positions=positions, - intermediate_tensors=pp_proxy_tensors, - inputs_embeds=input_embeds, - ) - if self._uses_forward_batch_context: - token = _current_forward_batch.set(forward_batch) - try: - hidden_states = self.model(**model_inputs) - finally: - _current_forward_batch.reset(token) - else: - hidden_states = self.model( - **model_inputs, - forward_batch=forward_batch, - get_embedding=get_embedding, - pp_proxy_tensors=pp_proxy_tensors, - **model_kwargs, - ) - - if self.pp_group.is_last_rank: - return self.logits_processor( - input_ids, - hidden_states, - self.model.lm_head, - forward_batch, + with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): + model_inputs = dict( + input_ids=input_ids, + positions=positions, + intermediate_tensors=pp_proxy_tensors, + inputs_embeds=input_embeds, ) - return hidden_states + if self._uses_forward_batch_context: + token = _current_forward_batch.set(forward_batch) + try: + hidden_states = self.model(**model_inputs) + finally: + _current_forward_batch.reset(token) + else: + hidden_states = self.model( + **model_inputs, + forward_batch=forward_batch, + get_embedding=get_embedding, + pp_proxy_tensors=pp_proxy_tensors, + **model_kwargs, + ) + + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, + hidden_states, + self.model.lm_head, + forward_batch, + ) + return hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # The passed `weights` iterable from sglang is ignored because ATOM @@ -140,9 +226,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # sglang's default weight iterator. from atom.model_loader.loader import load_model_in_plugin_mode - return load_model_in_plugin_mode( - model=self.model, config=self.model.atom_config, prefix="model." - ) + with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): + return load_model_in_plugin_mode( + model=self.model, config=self.atom_config, prefix="model." + ) EntryClass = [] diff --git a/atom/plugin/sglang/models/deepseek_nextn_wrapper.py b/atom/plugin/sglang/models/deepseek_nextn_wrapper.py new file mode 100644 index 000000000..5993624c1 --- /dev/null +++ b/atom/plugin/sglang/models/deepseek_nextn_wrapper.py @@ -0,0 +1,199 @@ +"""ATOM DeepSeek NextN wrapper for SGLang external loading. + +This keeps SGLang's draft architecture name (`DeepseekV3ForCausalLMNextN`) +so ModelRegistry can override the upstream implementation, but delegates the +actual draft core to ATOM's `DeepSeekMTP`. +""" + +import logging +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn + +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import get_global_server_args + +from atom.config import SpeculativeConfig +from atom.plugin.config import generate_atom_config_for_plugin_mode +from atom.plugin.sglang.attention_backend.sgl_attention_mla import ( + setup_deepseek_for_sglang, +) +from atom.plugin.sglang.models.base_model_wrapper import ( + _current_forward_batch, + plugin_runtime_scope, +) + +logger = logging.getLogger("atom.plugin.sglang.models") + + +def _sync_replaced_weights() -> None: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def _replace_weight(module: nn.Module, attr_name: str, weight) -> None: + if hasattr(module, attr_name): + delattr(module, attr_name) + setattr(module, attr_name, weight) + + +def _set_runtime_layer_id(layer_module: nn.Module, layer_id: int) -> None: + if hasattr(layer_module, "layer_id"): + layer_module.layer_id = layer_id + if hasattr(layer_module, "layer_num"): + layer_module.layer_num = layer_id + + +def _retag_mtp_runtime_layer_ids(model: nn.Module) -> None: + """Retag MTP runtime layer ids to draft-local indices. + + ATOM's DeepSeekMTP keeps checkpoint/global layer numbering (e.g. 61, 62...) + in module prefixes so weight remapping still works. SGLang's draft KV cache, + however, allocates layers using draft-local indices (0..num_nextn_layers-1). + Rebind only the runtime ids used by the attention/KV-cache path. + """ + + for local_layer_id, mtp_layer in enumerate(model.model.layers.values()): + mtp_block = mtp_layer.mtp_block + self_attn = mtp_block.self_attn + + _set_runtime_layer_id(self_attn, local_layer_id) + + for attr_name in ("mla_attn", "attn_mha"): + attn_obj = getattr(self_attn, attr_name, None) + if attn_obj is None: + continue + _set_runtime_layer_id(attn_obj, local_layer_id) + nested_attn = getattr(attn_obj, "attn", None) + if nested_attn is not None: + _set_runtime_layer_id(nested_attn, local_layer_id) + + +class DeepseekV3ForCausalLMNextN(nn.Module): + """SGLang-compatible draft wrapper backed by ATOM's `DeepSeekMTP`.""" + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + del prefix + super().__init__() + + logger.info("Initializing ATOM backend for %s", self.__class__.__name__) + + self.pp_group = get_pp_group() + self.quant_config = quant_config + self.config = config + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + + with plugin_runtime_scope(framework="sglang"): + self.atom_config = generate_atom_config_for_plugin_mode(config) + + # Draft workers need ATOM's MTP-specific config semantics rather than the + # default target-model translation used by the generic plugin wrapper. + SpeculativeConfig.hf_config_override(self.atom_config.hf_config) + + with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): + from atom.plugin.register import ( + init_aiter_dist, + register_ops_to_sglang, + set_attn_cls, + ) + from atom.models.deepseek_mtp import DeepSeekMTP + + register_ops_to_sglang(atom_config=self.atom_config) + set_attn_cls() + init_aiter_dist(config=self.atom_config) + + self.model = DeepSeekMTP(atom_config=self.atom_config) + self.model.atom_config = self.atom_config + setup_deepseek_for_sglang(self.model) + _retag_mtp_runtime_layer_ids(self.model) + + self.logits_processor = LogitsProcessor(config) + self.lm_head = self._first_mtp_layer().shared_head.head + + def _mtp_layers(self): + return list(self.model.model.layers.values()) + + def _first_mtp_layer(self): + layers = self._mtp_layers() + if not layers: + raise ValueError("DeepSeekMTP does not contain any draft layers") + return layers[0] + + def get_embed_and_head(self): + return self.model.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + self.set_embed(embed) + for mtp_layer in self._mtp_layers(): + _replace_weight(mtp_layer.shared_head.head, "weight", head) + self.lm_head = self._first_mtp_layer().shared_head.head + _sync_replaced_weights() + + def set_embed(self, embed): + _replace_weight(self.model.model.embed_tokens, "weight", embed) + _sync_replaced_weights() + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + **kwargs, + ): + del kwargs + if forward_batch.spec_info is None: + raise ValueError("DeepSeek MTP draft forward requires speculative info") + + with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): + token = _current_forward_batch.set(forward_batch) + try: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + hidden_states=forward_batch.spec_info.hidden_states, + inputs_embeds=input_embeds, + ) + finally: + _current_forward_batch.reset(token) + + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + ) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + del weights + from atom.model_loader.loader import load_model + + server_args = get_global_server_args() + draft_model_path = ( + server_args.speculative_draft_model_path or server_args.model_path + ) + self.atom_config.model = draft_model_path + with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): + return load_model( + model=self.model, + model_name_or_path=draft_model_path, + hf_config=self.atom_config.hf_config, + load_dummy=self.atom_config.load_dummy, + spec_decode=True, + ) + +EntryClass = [DeepseekV3ForCausalLMNextN] From 97cb0eb558aca169414b72ee7f01e74c99bc104c Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Tue, 14 Apr 2026 08:19:42 +0000 Subject: [PATCH 2/4] Add work log for claude debug --- ...deepseek-speculative-attention-metadata.md | 721 ++++++++++ ...8-sglang-attention-backend-fields-guide.md | 908 ++++++++++++ ...026-04-08-sglang-kv-cache-storage-guide.md | 692 ++++++++++ ...glang-speculative-decoding-architecture.md | 910 ++++++++++++ .../2026-04-08-vllm-continuous-batching.md | 1214 +++++++++++++++++ ...cudagraph-prefill-decode-metadata-guide.md | 734 ++++++++++ ...simple-prefill-cudagraph-metadata-guide.md | 866 ++++++++++++ work_log/MTP/MTP-2026-04-08.md | 525 +++++++ work_log/MTP/MTP-2026-04-09.md | 715 ++++++++++ work_log/MTP/MTP-2026-04-10.md | 801 +++++++++++ 10 files changed, 8086 insertions(+) create mode 100644 work_log/MTP/2026-04-08-deepseek-speculative-attention-metadata.md create mode 100644 work_log/MTP/2026-04-08-sglang-attention-backend-fields-guide.md create mode 100644 work_log/MTP/2026-04-08-sglang-kv-cache-storage-guide.md create mode 100644 work_log/MTP/2026-04-08-sglang-speculative-decoding-architecture.md create mode 100644 work_log/MTP/2026-04-08-vllm-continuous-batching.md create mode 100644 work_log/MTP/2026-04-10-sglang-cudagraph-prefill-decode-metadata-guide.md create mode 100644 work_log/MTP/2026-04-10-sglang-simple-prefill-cudagraph-metadata-guide.md create mode 100644 work_log/MTP/MTP-2026-04-08.md create mode 100644 work_log/MTP/MTP-2026-04-09.md create mode 100644 work_log/MTP/MTP-2026-04-10.md diff --git a/work_log/MTP/2026-04-08-deepseek-speculative-attention-metadata.md b/work_log/MTP/2026-04-08-deepseek-speculative-attention-metadata.md new file mode 100644 index 000000000..bc2e136a6 --- /dev/null +++ b/work_log/MTP/2026-04-08-deepseek-speculative-attention-metadata.md @@ -0,0 +1,721 @@ +# 2026-04-08 DeepSeek Speculative 与 Attention Metadata 关系笔记 + +## 文档目的 + +本文专门解释一个在调试 DeepSeek speculative / MTP 时非常关键、但又很容易被低估的问题: + +- **speculative decoding 和 attention metadata 到底是什么关系?** + +对于 DeepSeek 这类 MLA 模型来说,这个问题尤其重要。因为 speculative decoding +不是简单地“多跑一个 draft model”,它会直接改变: + +- 当前 batch 有多少 query token +- 每个 query 应该看到哪些 KV +- 这些 KV 在 paged KV cache 中的索引方式 +- 是否需要树状 mask / causal mask +- MLA kernel 需要的 workspace / split / persistent metadata + +换句话说: + +**speculative decoding 在 runtime 层的本质,就是不断重写 attention metadata。** + + +## 一句话理解 + +可以把 attention metadata 理解为: + +- “这一次 attention 要怎么看 KV cache”的说明书 + +而 speculative decoding 做的事情,本质上就是不断改变这份说明书: + +- normal decode:每个请求 1 个 query,查自己已有上下文 +- draft extend:一次要处理多个 draft token,需要新的 `qo_indptr` 与 mask +- target verify:要同时验证多个候选 token,query 形状和 KV 长度都变了 +- DeepSeek MLA / MTP:`max_q_len`、`kv_indptr`、`qo_indptr`、`work_metadata` + 会直接影响 kernel 如何执行 + + +## 1. 为什么这个问题在 DeepSeek 上特别重要 + +DeepSeek 使用 MLA(Multi-head Latent Attention)后,attention metadata 的作用比普通 +MHA 更重: + +- 普通 MHA 更多是 query/key/value 张量形状和 mask 变化 +- MLA 还要额外构造: + - `kv_indptr` + - `kv_indices` + - `qo_indptr` + - `kv_last_page_len` + - `max_q_len` + - `work_metadata` + - `work_info_set` + - `reduce_indptr` + - `reduce_final_map` + - `reduce_partial_map` + +这些量直接决定: + +- MLA persistent kernel 如何分块 +- 每个 query 要从 paged KV cache 里取哪些 token +- multi-query(例如 verify / MTP)时 query 维度如何展开 + +所以在 DeepSeek speculative 路径中,真正最敏感的往往不是 model forward 本身, +而是 **attention metadata 是否按正确语义构出来**。 + + +## 2. 先看三层 batch 抽象 + +核心文件: + +- `sglang/python/sglang/srt/model_executor/forward_batch_info.py` +- `sglang/python/sglang/srt/managers/schedule_batch.py` + +SGLang 中 batch 数据结构有三层: + +- `ScheduleBatch` +- `ModelWorkerBatch` +- `ForwardBatch` + +源码注释位置: + +- `forward_batch_info.py` 文件开头 + +这三层的职责可以粗略理解为: + +- `ScheduleBatch` + - scheduler 视角 + - 关注请求、prefix、token、调度状态 +- `ModelWorkerBatch` + - worker 视角 + - 关注一次 GPU forward 所需字段 +- `ForwardBatch` + - backend / kernel 视角 + - 关注 query、KV、cache、metadata + +**attention metadata 的最终落点是在 `ForwardBatch -> attn_backend.init_forward_metadata()`** +这一层。 + + +## 3. `ForwardMode`:speculative 如何改变 metadata 初始化分支 + +核心文件: + +- `sglang/python/sglang/srt/model_executor/forward_batch_info.py` + +关键枚举: + +- `ForwardMode` + +关键 speculative mode: + +- `TARGET_VERIFY` +- `DRAFT_EXTEND` +- `DRAFT_EXTEND_V2` + +关键代码位置: + +- `ForwardMode` 定义:约 `74-179` + +最容易踩坑的一点: + +- `ForwardMode.is_extend()` 会把 `TARGET_VERIFY` 也算进去 + +对应逻辑: + +- `forward_batch_info.py` 约 `105-114` + +这意味着如果某个 backend 只是粗暴地区分: + +- decode +- extend + +而没有再细分: + +- target_verify +- draft_extend + +那么它很容易把 verify 当普通 extend 处理,然后在 metadata 上出错。 + + +## 4. speculative 信息是如何进入 attention 层的 + +核心抽象: + +- `SpecInput` + +文件: + +- `sglang/python/sglang/srt/speculative/spec_info.py` + +关键点: + +- `SpecInput` 不是附带信息,而是 speculative 与 attention metadata 的桥梁 +- 它负责携带: + - speculative token 相关信息 + - 需要的 positions + - `kv_indptr` / `kv_indices` + - `custom_mask` + - `accept_length` + - `draft_token_num` + - 其他草稿 / 验证所需状态 + +相关位置: + +- `SpecInputType`:约 `108-113` +- `SpecInput`:约 `116-143` + +这里有一个很重要的方法: + +- `get_spec_adjusted_global_num_tokens()` + +它说明 speculative decoding 会直接改变: + +- global num tokens +- logprob token 数 + +这也间接影响 batch padding 和后续 metadata 构造。 + + +## 5. `EagleVerifyInput`:speculative 到 metadata 的第一层接口 + +核心文件: + +- `sglang/python/sglang/srt/speculative/eagle_info.py` + +关键类: + +- `EagleVerifyInput` + +关键字段: + +- `draft_token` +- `custom_mask` +- `positions` +- `draft_token_num` +- `capture_hidden_mode` +- `seq_lens_sum` +- `seq_lens_cpu` + +代码位置: + +- `eagle_info.py` 约 `54-78` + +这些字段本身就已经说明了 speculative 和 metadata 的关系: + +- `draft_token` + - 决定 verify 阶段实际送入 target 的 query token +- `positions` + - 决定 RoPE / position indexing +- `draft_token_num` + - 决定一次 verify 需要几个 query +- `custom_mask` + - 决定树状 speculative 验证时的可见性 + + +## 6. verify 阶段:speculative 如何改写 batch + +### 6.1 v1 路径 + +关键文件: + +- `sglang/python/sglang/srt/speculative/eagle_worker.py` +- `sglang/python/sglang/srt/speculative/eagle_info.py` + +关键流程: + +1. `draft()` 先生成候选 token,形成 `EagleVerifyInput` +2. `verify()` 调 `spec_info.prepare_for_verify(batch, page_size)` +3. `batch.forward_mode` 被改成 `TARGET_VERIFY` +4. target worker 执行 verify forward + +关键位置: + +- `eagle_worker.py` 中 `verify()`:约 `699-788` +- `eagle_info.py` 中 `prepare_for_verify()`:约 `104-146` + + +### 6.2 `prepare_for_verify()` 改了什么 + +它主要会做: + +- `batch.input_ids = self.draft_token` +- 分配 `batch.out_cache_loc` +- 更新 `req_to_token_pool` + +也就是: + +- target verify 不再看原先的“普通 decode 单 token 输入” +- 而是把所有 draft token 当作本轮 query 批次 + +这已经说明: + +- verify 不是普通 decode +- verify 的 query 形状和 KV 形状都变了 +- 所以 attention metadata 必须重新构造 + + +## 7. `generate_attn_arg_prefill()`:draft_extend 的 metadata 生成器 + +文件: + +- `sglang/python/sglang/srt/speculative/eagle_info.py` + +关键方法: + +- `generate_attn_arg_prefill()` + +代码位置: + +- 约 `160-216` + +这个函数非常关键,因为它直接把 speculative 信息翻译成 attention metadata 里的核心索引: + +- `qo_indptr` +- `cum_kv_seq_len`(本质上就是 `kv_indptr`) +- `kv_indices` +- `custom_mask` + +可以理解为: + +- speculative 输入先描述“我要验证/扩展多少个 draft token、树结构是什么” +- `generate_attn_arg_prefill()` 再把这种高层语义翻译成 kernel 能消费的索引格式 + + +### 7.1 `qo_indptr` 是什么 + +在这里: + +- `qo_indptr` 表示 query output token 在 batch 中如何分段 + +例如: + +- 每个请求有 `draft_token_num` 个 query +- 那么 `qo_indptr` 就会按这个 query 数量分桶 + + +### 7.2 `kv_indptr` / `cum_kv_seq_len` 是什么 + +它表示: + +- 每个请求在当前 forward 中可见的 KV token 范围 + +draft / verify 会把: + +- 原始 `paged_kernel_lens` + +扩成: + +- `paged_kernel_lens + draft_token_num` + +这说明 speculative decoding 不是只“多几个 query”,而是连本轮可见 KV 长度都变了。 + + +### 7.3 `custom_mask` 是什么 + +对于树状 speculative decode: + +- 不是所有 draft token 都能互相看见 + +所以需要: + +- `custom_mask` + +来表示 tree-based causal structure。 + +这个量会在非 MLA MHA 路径里更直接地进入 attention kernel。 + + +## 8. v2 路径:`prepare_for_v2_verify()` 如何构造 verify metadata + +核心文件: + +- `sglang/python/sglang/srt/speculative/eagle_info_v2.py` + +关键方法: + +- `prepare_for_v2_verify()` + +代码位置: + +- `eagle_info_v2.py` 约 `213-270` + +这个方法做的事情可以理解成: + +1. 先按 speculative verify 语义设置: + - `batch.input_ids` + - `batch.out_cache_loc` +2. 把 `batch.forward_mode` 改成 `TARGET_VERIFY` +3. 通过 `ForwardBatch.init_new(batch, target_worker.model_runner)` + 得到真正的 `ForwardBatch` +4. 然后显式调用: + - `target_worker.model_runner.attn_backend.init_forward_metadata(verify_forward_batch)` + +这说明: + +- speculative verify 到 attention metadata 的连接点,不是在 model forward 里隐式发生的 +- 而是在 `prepare_for_v2_verify()` 中显式发生的 + + +## 9. attention metadata 长什么样 + +### 9.1 upstream SGLang `ForwardMetadata` + +文件: + +- `sglang/python/sglang/srt/layers/attention/aiter_backend.py` + +关键 dataclass: + +- `ForwardMetadata` + +代码位置: + +- `aiter_backend.py` 约 `76-95` + +关键字段: + +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `kv_last_page_len` +- `max_q_len` +- `max_kv_len` +- `work_metadata` +- `work_info_set` +- `reduce_indptr` +- `reduce_final_map` +- `reduce_partial_map` +- `num_kv_splits` +- `custom_mask` +- `mask_indptr` +- `max_extend_len` +- `fp8_prefill_kv_indices` + + +### 9.2 ATOM plugin 的 `ForwardMetadata` + +文件: + +- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` + +关键 dataclass: + +- `ForwardMetadata` + +代码位置: + +- `sgl_attn_backend.py` 约 `171-198` + +从字段上看,ATOM plugin 其实已经承认 speculative / MLA attention 需要这些索引和 workspace。 +所以当前问题不是“不知道这些量存在”,而是: + +- 没在 metadata init 分支上完全按 upstream 语义实现 + + +## 10. upstream `AiterAttnBackend.init_forward_metadata()` 如何按 speculative 分流 + +核心文件: + +- `sglang/python/sglang/srt/layers/attention/aiter_backend.py` + +关键方法: + +- `init_forward_metadata()` + +代码位置: + +- `aiter_backend.py` 约 `435-684` + +这是整条链最重要的代码之一。 + +它不是简单分成: + +- decode +- extend + +而是分成: + +1. `decode_or_idle` +2. `draft_extend` +3. `target_verify` +4. 普通 extend + + +### 10.1 普通 decode / idle + +逻辑: + +- `spec_info` 为空时,按普通 decode 构造 +- `spec_info` 不为空时,直接复用 `spec_info.kv_indptr / kv_indices` + +这说明 speculative 已经开始介入 decode metadata。 + + +### 10.2 `draft_extend` + +逻辑: + +- 调 `spec_info.generate_attn_arg_prefill()` +- 拿到: + - `kv_indices` + - `kv_indptr` + - `qo_indptr` + - `custom_mask` +- MLA 路再进一步根据 `extend_seq_lens_cpu` + 计算 `max_seqlen_qo` 和 persistent kernel metadata + +关键位置: + +- `aiter_backend.py` 约 `526-606` + + +### 10.3 `target_verify` + +这是最关键的一支。 + +逻辑: + +- 不依赖普通 extend 的 `extend_seq_lens` +- 直接用: + - `draft_num = spec_info.draft_token_num` + - `kv_lens = forward_batch.seq_lens + draft_num` +- 自己构造: + - `qo_indptr` + - `kv_indptr` + - `kv_indices` +- 对 MLA 路: + - `max_q_len = draft_num` + +关键位置: + +- `aiter_backend.py` 约 `607-684` + +这个分支完美说明: + +**verify 不是普通 extend,speculative 会直接重定义 query 长度和 KV 长度。** + + +## 11. DeepSeek MLA:为什么 speculative 更像“metadata 问题” + +对于 DeepSeek MLA 来说,attention forward 真正吃的是: + +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `kv_last_page_len` +- `max_q_len` +- `work_metadata` / `reduce_*` + +如果这些量不对: + +- 哪怕 model forward、q/k/v 张量本身都没问题 +- kernel 也会在错误的 KV 范围上工作 + +这就是为什么调试 speculative 时,attention metadata 的正确性往往比 model 本身更先决定成败。 + + +## 12. 本次调试得到的一个关键教训 + +在 `ATOM plugin` 当前实现中: + +- `sgl_attn_backend.py` 里的 `_forward_extend_mla()` 已经认识: + - `TARGET_VERIFY` + - `DRAFT_EXTEND` + +代码位置: + +- `sgl_attn_backend.py` 约 `1001-1022` + +但 metadata 初始化层还没有完全按 upstream 分支细化: + +- `init_forward_metadata()` 仍是: + - `decode_or_idle` + - else -> `extend` + +代码位置: + +- `sgl_attn_backend.py` 约 `282-288` + +于是 `TARGET_VERIFY` 会被误送进普通 `_init_extend_mla()`: + +- 它会错误假设 `forward_batch.extend_seq_lens` 一定存在 + +而在 verify 路径下: + +- `extend_seq_lens` 本来就可能是 `None` + +这就是为什么当前错误看起来像: + +- `NoneType has no attribute max` + +实际上本质是: + +- **speculative 和 attention metadata 的语义没有对齐** + + +## 13. 从 ATOM 原生 MTP 再看一次 metadata 的重要性 + +如果看 ATOM 原生链路: + +- `ATOM/atom/spec_decode/eagle.py` +- `ATOM/atom/model_ops/attentions/aiter_mla.py` + +会发现 speculative / MTP 对 attention metadata 的耦合更直接。 + +### 13.1 `EagleProposer.propose()` + +关键位置: + +- `atom/spec_decode/eagle.py` 约 `94-190` + +在多步 draft 过程中,会不断更新: + +- `attn_metadata.max_seqlen_q` +- `attn_metadata.max_seqlen_k` +- `kv_indptr` +- `kv_indices` +- `cu_seqlens_q` +- `slot_mapping` +- `kv_last_page_lens` + +并调用: + +- `prepare_mtp_decode()` + +这说明在 ATOM 原生实现里: + +- speculative 不是 attention 上的一点点附加参数 +- 而是会不断重写 attention metadata + + +### 13.2 `prepare_mtp_decode()` + +文件: + +- `ATOM/atom/model_ops/attentions/aiter_mla.py` + +关键位置: + +- `prepare_mtp_decode()`:约 `225-250` + +作用: + +- 为多 token 预测构造 MTP decode 需要的 KV / worker metadata + +同文件里还有一个重要信号: + +- `prepare_decode()` 会在有 drafter 时把 + `max_seqlen_q = drafter.mtp_k + 1` + +位置: + +- `aiter_mla.py` 约 `352-357` + +这再次说明: + +- speculative / MTP 本质上会改变 query 维度 +- query 维度一变,attention metadata 就必须重建 + + +## 14. 调试 speculative + metadata 时的实用检查表 + +如果后续继续调试 DeepSeek speculative / MTP,建议优先检查下面几项: + +### 1. 当前 `ForwardMode` 是什么 + +看: + +- `decode` +- `target_verify` +- `draft_extend` +- `draft_extend_v2` + +如果 mode 判断错了,metadata 分支通常也会错。 + + +### 2. 当前 `spec_info` 是不是空 + +如果 `spec_info` 不为空,就不应该再走普通 extend 的 metadata 逻辑。 + + +### 3. `qo_indptr` 是否和 speculative token 数一致 + +例如 verify 路径里: + +- `max_q_len` 应该接近 `draft_token_num` + +而不是普通 decode 的 `1`。 + + +### 4. `kv_indptr / kv_indices` 是否按 speculative 后的新 KV 长度构造 + +verify 阶段一般应当看到: + +- `kv_lens = seq_lens + draft_token_num` + +而不是原始 `seq_lens`。 + + +### 5. 是否错误依赖了 `extend_seq_lens` + +普通 extend 可以依赖: + +- `extend_seq_lens` + +但 `target_verify` 不应简单照搬这套假设。 + + +### 6. 是否需要 `custom_mask` + +树状 speculative / topk 路径下: + +- `custom_mask` + +常常是必须的;它缺失时可能不会立刻报错,但结果会错。 + + +## 15. 推荐阅读顺序 + +如果以后要重新从头搞清楚 “DeepSeek speculative 与 attention metadata 的关系”, +推荐按下面顺序阅读: + +1. `sglang/python/sglang/srt/model_executor/forward_batch_info.py` + - 看 `ForwardMode` +2. `sglang/python/sglang/srt/speculative/spec_info.py` + - 看 `SpecInput` +3. `sglang/python/sglang/srt/speculative/eagle_info.py` + - 看 `EagleVerifyInput` +4. `sglang/python/sglang/srt/speculative/eagle_info_v2.py` + - 看 `prepare_for_v2_verify()` +5. `sglang/python/sglang/srt/layers/attention/aiter_backend.py` + - 看 `init_forward_metadata()` 的四种 speculative 分支 +6. `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` + - 对照 plugin 当前实现和 upstream 差异 +7. `ATOM/atom/spec_decode/eagle.py` + - 看 ATOM 原生 speculative 是怎么驱动 attn metadata 更新的 +8. `ATOM/atom/model_ops/attentions/aiter_mla.py` + - 看 MTP decode 的 metadata 准备逻辑 + + +## 16. 最终总结 + +对于 DeepSeek 而言: + +- speculative decoding 的重点不只是 draft model +- attention metadata 才是把 speculative 语义真正落到 kernel 的关键层 + +可以用下面一句话概括: + +**draft / verify 负责决定“要处理哪些 token”,attention metadata 负责把这个决定变成 kernel 可执行的 KV / Q 索引和 workspace 说明。** + +因此,后续如果要在 `ATOM + SGLang plugin` 路径真正接通 DeepSeek MTP, +核心工作并不是只把 draft model 换成 `ATOM DeepSeekMTP`,还包括: + +- 让 plugin 的 attention metadata 初始化完整理解 + - `TARGET_VERIFY` + - `DRAFT_EXTEND` + - `MTP 多 query` + - `custom_mask` + - `qo_indptr / kv_indptr / kv_indices` + +只有这层语义也打通,DeepSeek speculative 才算真正可用。 diff --git a/work_log/MTP/2026-04-08-sglang-attention-backend-fields-guide.md b/work_log/MTP/2026-04-08-sglang-attention-backend-fields-guide.md new file mode 100644 index 000000000..b7c170af7 --- /dev/null +++ b/work_log/MTP/2026-04-08-sglang-attention-backend-fields-guide.md @@ -0,0 +1,908 @@ +# SGLang Attention Backend 字段说明 + +## 文档目的 + +这篇文档专门解释 `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` +里 `ForwardMetadata` 的核心字段,重点放在: + +- 这些字段在调度链路里是怎么来的 +- 它们分别表示什么语义 +- 它们的 shape 是什么 +- 它们和 SGLang 的 KV cache 存储结构是什么关系 + +本文刻意**不展开** `reduce_indptr`、`reduce_final_map`、`reduce_partial_map` +这类更偏 kernel 内部 workspace 的字段,只在必要时顺带提一句。 + + +## 一句话理解 + +可以把 `ForwardMetadata` 理解成: + +- `scheduler / ForwardBatch` 已经决定了“这一步要算哪些 request、每个 request 算多少 query、这些 query 应该看到哪些 KV” +- `attn_backend.init_forward_metadata()` 负责把这个高层语义,转换成 attention kernel 真正能消费的低层索引 + +其中最关键的就是三类信息: + +- **Q 侧分段信息**:`qo_indptr`, `max_q_len` +- **KV 侧分段信息**:`kv_indptr`, `kv_indices`, `kv_last_page_len`, `max_kv_len` +- **非 MLA 下的 page 化信息**:`page_table`, `kv_lens` + + +## 1. 三层 batch 抽象 + +SGLang 里和 attention metadata 直接相关的 batch 抽象有三层: + +- `ScheduleBatch` +- `ModelWorkerBatch` +- `ForwardBatch` + +其中: + +- `ScheduleBatch` + - scheduler 视角 + - 关心请求、prefix、seq len、cache slot 分配 +- `ModelWorkerBatch` + - worker 视角 + - 是一次 GPU forward 所需字段的中间态 +- `ForwardBatch` + - attention backend / kernel 视角 + - 大部分字段已经是 GPU tensor + +可以粗略画成: + +```mermaid +flowchart LR + A[Scheduler / ScheduleBatch] + B[ModelWorkerBatch] + C[ForwardBatch] + D[init_forward_metadata] + E[ForwardMetadata] + F[Attention Kernel] + + A --> B --> C --> D --> E --> F +``` + +`ForwardMetadata` 就是 `ForwardBatch` 再往下走一步,把“批次语义”翻译成“索引语义”的结果。 + + +## 2. 调度到 metadata 的主链路 + +最值得记住的链路是: + +1. scheduler 决定这一步 batch 里有哪些 request +2. scheduler 为这些 request 分配或复用 KV slot +3. `ScheduleBatch.get_model_worker_batch()` 把调度状态打包 +4. `ForwardBatch.init_new()` 把 CPU 侧 list / 状态变成 GPU tensor +5. `attn_backend.init_forward_metadata()` 生成 `ForwardMetadata` + +对应几个关键字段来源如下: + +- `req_pool_indices` + - 来自 `ScheduleBatch` + - 表示每个 request 在 `ReqToTokenPool.req_to_token` 里的“行号” +- `seq_lens` + - 每个 request 当前参与 attention 的 KV 长度 +- `out_cache_loc` + - 本轮新 token 写入 KV cache 的物理 slot +- `extend_seq_lens` + - prefill / extend 时,每个 request 本轮真正新增了多少 query token +- `spec_info` + - speculative 路径下额外提供 verify / draft_extend 需要的 query 结构 + + +## 3. 先看 KV cache 的两层存储 + +理解 `kv_indptr` / `kv_indices` 之前,必须先看清 SGLang 的 KV cache 存储不是“一块连续上下文”,而是两层映射: + +- `ReqToTokenPool` +- `TokenToKVPool` + +### 3.1 `ReqToTokenPool` + +文件: + +- `sglang/python/sglang/srt/mem_cache/memory_pool.py` + +核心张量: + +- `req_to_token` + +shape: + +- `[req_pool_size, max_context_len]` +- dtype 通常是 `int32` + +语义: + +- 行:一个 request slot +- 列:这个 request 的逻辑 token 位置 +- 值:该位置对应的 **物理 KV slot id** + +也就是说,`req_to_token` 不是存 K/V 本身,而是存: + +- `request 的第 i 个 token,实际写到了 token_to_kv_pool 的哪个 slot` + +可以理解成: + +```text +req_to_token[req_pool_idx, token_pos] = physical_kv_slot +``` + +### 3.2 `TokenToKVPool` + +它是真正存物理 K/V 的地方。 + +根据注意力形式不同,常见有两类: + +- `MHATokenToKVPool` +- `MLATokenToKVPool` + +### 3.3 MHA KV cache 形状 + +文件: + +- `sglang/python/sglang/srt/mem_cache/memory_pool.py` + +MHA 下,每层通常有两块 buffer: + +- `k_buffer[layer]` +- `v_buffer[layer]` + +shape: + +- `k_buffer[layer]`: `[(size + page_size), num_kv_heads, head_dim]` +- `v_buffer[layer]`: `[(size + page_size), num_kv_heads, v_head_dim]` + +这里的第一维就是 **物理 token slot**。 + +也就是说: + +- `loc = 12345` +- `k_buffer[layer][12345]` +- `v_buffer[layer][12345]` + +就是这个 token 在该层的物理 KV 存储位置。 + +额外的 `+ page_size` 是为了预留 padding / dummy 写入空间。源码里有一句很关键: + +- padded slot 0 用于 padded token 的 dummy output write + +所以它不是严格只分配 `size` 个可见 token 位置,而是多留了一点缓冲。 + +### 3.4 MLA KV cache 形状 + +MLA 下不是单独一块 K、一块 V,而是一个合并后的 latent KV buffer。 + +shape: + +- `kv_buffer[layer]`: `[(size + page_size), 1, kv_cache_dim]` + +其中: + +- `kv_cache_dim = kv_lora_rank + qk_rope_head_dim` + - 对 DeepSeek MLA,通常就是 latent KV 部分加 rope 部分 + +这意味着: + +- MHA:一个 slot 对应 `K` 和 `V` +- MLA:一个 slot 对应一个融合后的 latent cache 向量 + +对 DeepSeek MLA,常见理解方式是: + +- 前半段:`kv_a` / latent KV +- 后半段:`k_pe` / rope 部分 + + +## 4. `req_to_token` 和 `out_cache_loc` 的关系 + +调度器在每轮 forward 前,会先给新 token 分配物理 slot,得到: + +- `out_cache_loc` + +shape: + +- `[num_new_tokens]` + +语义: + +- 本轮所有新增 token 应该写到哪些物理 KV slot + +然后再把这些 slot 填回 `req_to_token` 的对应位置。 + +可以画成: + +```mermaid +flowchart TD + A[request A / req_pool_idx=7] + B[request B / req_pool_idx=9] + C[out_cache_loc = new physical slots] + D[assign_req_to_token_pool] + E[ReqToTokenPool.req_to_token] + F[TokenToKVPool buffers] + + A --> C + B --> C + C --> D --> E + E --> F +``` + +本质上: + +- `out_cache_loc` 决定“新 token 写哪里” +- `req_to_token` 记录“逻辑位置到物理 slot 的长期映射” + + +## 5. `ForwardMetadata` 核心字段总览 + +下面重点解释这些字段: + +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `kv_last_page_len` +- `max_q_len` +- `max_kv_len` +- `page_table` +- `kv_lens` + +### 5.1 一个总表 + +| 字段 | 常见 shape | 主要用于 | 一句话语义 | +|------|------------|----------|------------| +| `kv_indptr` | `[bs + 1]` | MLA | KV flatten 后每个 request 的段边界 | +| `kv_indices` | `[sum(kv_lens)]` | MLA | flatten 后每个 KV token 对应的物理 slot | +| `qo_indptr` | `[bs + 1]` | MLA / speculative | Q flatten 后每个 request 的段边界 | +| `kv_last_page_len` | `[bs]` | MLA paged kernel | 每个 request 最后一个 page 里有多少有效 token | +| `max_q_len` | `int` | 所有 attention kernel | batch 内单个 request 的最大 query 长度 | +| `max_kv_len` | `int` or `None` | extend / prefill | batch 内单个 request 的最大 KV 长度 | +| `page_table` | `[bs, max_pages]` | 非 MLA | request -> page id 的二维表 | +| `kv_lens` | `[bs]` | 非 MLA | 每个 request 的 KV 长度 | + + +## 6. `kv_indptr` 是什么 + +### 6.1 语义 + +`kv_indptr` 是一个 CSR 风格的前缀和数组。 + +shape: + +- `[bs + 1]` + +语义: + +- 第 `i` 个 request 的 KV 段,在 `kv_indices` 中的范围是: + - `[kv_indptr[i], kv_indptr[i + 1])` + +所以它不是“KV 长度本身”,而是: + +- `flatten 之后每段的起止边界` + +### 6.2 它通常怎么构造 + +典型构造方式: + +```text +kv_indptr[0] = 0 +kv_indptr[1:] = cumsum(kv_lens) +``` + +其中: + +- decode 下,`kv_lens` 往往就是 `seq_lens` +- target_verify 下,MLA 常是 `seq_lens + draft_token_num` +- draft_extend 下,可能来自 speculative 专门生成的 prefill 参数 + +### 6.3 例子 + +假设 batch 里有两个 request: + +- request A 的 KV 长度 = 5 +- request B 的 KV 长度 = 3 + +那么: + +```text +kv_lens = [5, 3] +kv_indptr = [0, 5, 8] +``` + +表示: + +- request A 对应 `kv_indices[0:5]` +- request B 对应 `kv_indices[5:8]` + + +## 7. `kv_indices` 是什么 + +### 7.1 语义 + +`kv_indices` 是一个 flatten 后的一维数组。 + +shape: + +- `[sum(kv_lens)]` + +语义: + +- 它的每个元素都是 **物理 KV slot id** +- 这些 slot id 来自 `req_to_token` + +换句话说: + +- `kv_indices` 是“这一步 attention 真正要访问哪些物理 KV token” + +### 7.2 它和 `req_to_token` 的关系 + +`create_flashinfer_kv_indices_triton(...)` 会根据: + +- `req_pool_indices` +- `req_to_token` +- `kv_lens` +- `kv_indptr` + +把每个 request 对应的那一段 `req_to_token[row, :kv_len]` +抽出来,拼成一个一维的 `kv_indices`。 + +### 7.3 例子 + +假设: + +- `req_pool_indices = [7, 9]` +- `req_to_token[7, 0:5] = [100, 101, 102, 103, 120]` +- `req_to_token[9, 0:3] = [200, 201, 220]` + +那么: + +```text +kv_indptr = [0, 5, 8] +kv_indices = [100, 101, 102, 103, 120, 200, 201, 220] +``` + +这就表示: + +- request A 的 attention 访问物理 slot `100,101,102,103,120` +- request B 的 attention 访问物理 slot `200,201,220` + +可以把它理解为: + +```mermaid +flowchart LR + A["req_to_token[row=7] = [100,101,102,103,120,...]"] + B["req_to_token[row=9] = [200,201,220,...]"] + C["kv_indices = [100,101,102,103,120,200,201,220]"] + + A --> C + B --> C +``` + + +## 8. `qo_indptr` 是什么 + +### 8.1 语义 + +`qo_indptr` 和 `kv_indptr` 是对称的,但它描述的是 **Q / output 侧**。 + +shape: + +- `[bs + 1]` + +语义: + +- 第 `i` 个 request 的 query 段,在 flatten 后的 Q 张量中的范围是: + - `[qo_indptr[i], qo_indptr[i + 1])` + +### 8.2 为什么它很重要 + +attention backend 经常把 batch 里的 query token flatten 成一个二维/三维张量去跑 kernel。 + +这时 kernel 需要知道: + +- 哪些 query 属于 request A +- 哪些 query 属于 request B + +`qo_indptr` 就是这份分段说明书。 + +### 8.3 不同模式下的典型含义 + +- decode + - 每个 request 只有 1 个 query + - 所以常见是 `[0, 1, 2, ..., bs]` +- 普通 extend / prefill + - 每个 request 的 query 数就是 `extend_seq_lens[i]` + - 所以通常是 `cumsum(extend_seq_lens)` +- target_verify + - 每个 request 通常有 `draft_token_num` 个 query + - 所以常是 `[0, d, 2d, 3d, ...]` +- draft_extend + - 每个 request 的 query 数可能不同 + - 常来自 `accept_length` 或 `extend_seq_lens` + +### 8.4 例子 + +假设有两个 request: + +- request A 本轮新增 query = 3 +- request B 本轮新增 query = 2 + +那么: + +```text +extend_seq_lens = [3, 2] +qo_indptr = [0, 3, 5] +``` + +表示: + +- request A 的 query 是 flatten Q 中的 `[0:3]` +- request B 的 query 是 flatten Q 中的 `[3:5]` + + +## 9. `kv_last_page_len` 是什么 + +### 9.1 语义 + +这是分页 KV cache 下很重要的一个辅助量。 + +shape: + +- `[bs]` + +语义: + +- 每个 request 的最后一个 page 里,有多少个有效 token + +因为 paged KV cache 不是要求每个 request 的长度都刚好是 `page_size` 的整数倍,所以最后一个 page 往往只有一部分有效。 + +### 9.2 例子 + +假设 `page_size = 4`: + +- request A 的 KV 长度 = 5 +- request B 的 KV 长度 = 3 + +那么: + +- request A 有 2 个 page,最后一个 page 有 1 个有效 token +- request B 有 1 个 page,最后一个 page 有 3 个有效 token + +对应: + +```text +kv_last_page_len = [1, 3] +``` + + +## 10. `max_q_len` 和 `max_kv_len` + +### 10.1 `max_q_len` + +语义: + +- batch 内单个 request 的最大 query 长度 + +常见来源: + +- decode: `1` +- 普通 extend: `max(extend_seq_lens)` +- target_verify: `draft_token_num` +- draft_extend: 常是 `max(extend_seq_lens)` 或 `max(accept_length)` + +shape: + +- Python `int` + +作用: + +- kernel 需要知道 batch 内最大 query 段长度,来决定 tile / workspace / pad 方式 + +### 10.2 `max_kv_len` + +语义: + +- batch 内单个 request 的最大 KV 长度 + +常见来源: + +- 普通 extend / prefill:`max(seq_lens)` +- 某些 decode / verify MLA 路径里可能不单独存,或者设成 `None` + +shape: + +- Python `int` 或 `None` + + +## 11. `page_table` 和 `kv_lens` + +这两个字段更偏 **非 MLA** 路径,是 `kv_indptr/kv_indices` 的 page 化替代表示。 + +### 11.1 `page_table` + +shape: + +- `[bs, max_num_pages_per_request]` + +语义: + +- 每一行对应一个 request +- 每个元素是一个 physical page id + +它不是 token-level 的 flatten 索引,而是 page-level 的二维映射。 + +### 11.2 `kv_lens` + +shape: + +- `[bs]` + +语义: + +- 每个 request 当前 KV 长度 + +kernel 会结合: + +- `page_table` +- `kv_lens` +- `page_size` + +来知道每个 request 该读哪些 page、最后一页有多少有效 token。 + +### 11.3 这组字段是不是只给 MLA 用 + +不是。 + +`ForwardMetadata` 更准确地说是一个: + +- **统一容器** + +它里面同时装了: + +- MLA 常用的 metadata +- MHA 常用的 metadata +- 两边都可能用到的通用字段 + +可以粗略分成三类: + +| 类别 | 字段 | +|------|------| +| 更偏 MLA | `kv_indptr`, `kv_indices`, `qo_indptr`, `kv_last_page_len` | +| 更偏 MHA | `page_table`, `kv_lens`, `pa_metadata_*` | +| 通用 | `max_q_len`, `max_kv_len` | + +也就是说: + +- **不是只有 MLA 才会创建 `ForwardMetadata`** +- 而是 **MLA 和 MHA 共用这个 dataclass** +- 只是不同 kernel 最终只消费其中的一部分字段 + +### 11.4 MHA 的 metadata 代码在哪里 + +如果想看 `sgl_attn_backend.py` 里 **MHA 真正的 metadata 路径**,主要看这几段: + +- `_init_decode_mha()` +- `_init_extend_mha()` +- `_build_pa_metadata_for_decode()` +- `_build_pa_metadata_for_prefill()` + +含义可以概括成: + +- `decode` + - 优先看 `page_table`, `kv_lens` + - 如果走 `pa_persistent_fwd`,再看 `pa_metadata_*` +- `extend / prefill` + - 主要看 `max_q_len`, `max_kv_len` + - page 化路径下也会继续依赖 `page_table`, `kv_lens` + +换句话说: + +- **MLA 更像 “token-level flattened 索引驱动”** +- **MHA 更像 “page-table / context-len 驱动”** + +### 11.5 为什么 MHA 通常不需要 `kv_last_page_len` + +这个问题最容易和 MLA 搞混。 + +核心原因是: + +- MHA 在这个 backend 里通常走的是: + - `page_table + kv_lens` + - 或者 `pa_metadata_*` +- MLA 则更依赖: + - `kv_indptr + kv_indices + kv_last_page_len` + +对 MHA 来说,kernel 经常直接拿到: + +- 每个 request 的 `context_lens` +- 每个 request 对应哪些 page(`page_table`) +- 固定的 `page_size` + +于是: + +- 最后一页有多少有效 token +- 可以由 `context_lens % page_size` +- 或更高层 page metadata 直接推出来 + +所以 MHA 不一定需要把: + +- “最后一个 page 的有效长度” + +单独存成 `kv_last_page_len`。 + +而 MLA 的 paged kernel 实现更偏 token-flatten / ragged 索引驱动,显式传: + +- `kv_last_page_len` + +会更直接、更方便。 + +### 11.6 为什么 MHA 通常不需要 `qo_indptr` + +`qo_indptr` 的本质是: + +- flatten 后 query 段的边界表 + +它在 MLA 里很重要,因为 MLA kernel 经常直接消费: + +- ragged / flatten 的 Q 段 +- 对应的 KV flatten 段 + +而 MHA 在这个 plugin 里常见有两类路径: + +#### 路径一:decode 的 `pa_fwd_asm` / `pa_persistent_fwd` + +这类 kernel 更偏: + +- `block_tables = page_table` +- `context_lens = kv_lens` + +decode 下每个 request 本来就只有 1 个 query,所以 query 分段是隐含的: + +- batch 第 0 个 query 属于 request 0 +- batch 第 1 个 query 属于 request 1 + +这时单独维护 `qo_indptr` 不是必须的。 + +#### 路径二:extend 的 `flash_attn_varlen_func` + +这条路在当前 plugin 里更依赖: + +- 显式传入的 `q`, `k`, `v` +- `max_q_len`, `max_kv_len` +- 以及运行时构出来的 `cu_seqlens_q` + +这里 query 的分段信息已经由: + +- 输入张量本身 +- `cu_seqlens_q` + +表达出来了,所以 `qo_indptr` 也不是核心字段。 + +因此可以把它理解成: + +- **MLA 喜欢把 Q 段边界显式放进 metadata** +- **MHA 更常把 Q 段边界隐含在输入张量和专用 kernel 参数里** + +### 11.7 为什么 MHA 通常不需要 `kv_indptr + kv_indices` + +`kv_indptr + kv_indices` 的组合,本质上是在表达: + +- “把所有 request 的 KV token 拉平成一条长数组之后,每个 request 的 KV 段从哪里开始,到哪里结束” + +这是一种非常适合: + +- ragged token-level attention +- MLA flatten KV 访问 + +的表示法。 + +但 MHA 在 paged cache 下经常不需要把 KV 先 flatten 成 token 列表。 + +因为它可以直接用: + +- `page_table` +- `kv_lens` + +来表达同一件事: + +- 第 `i` 个 request 对应哪些 page +- 这些 page 中实际有多少 token 是有效的 + +可以类比为: + +- `kv_indptr + kv_indices` + - 是 token-level 的“稀疏展开形式” +- `page_table + kv_lens` + - 是 page-level 的“块索引形式” + +两者本质都在回答: + +- “本轮 attention 应该读哪些 KV” + +只是表达层次不同。 + +所以不是说 MHA **完全不能** 用 `kv_indptr + kv_indices`, +而是: + +- 在当前 backend 的主实现里,MHA 的更自然表示是 `page_table + kv_lens` +- `kv_indptr + kv_indices` 在 MHA 路径里通常不是主角 + +### 11.8 一个简化对照表 + +| 场景 | 更核心的 metadata | +|------|-------------------| +| MLA decode | `kv_indptr`, `kv_indices`, `qo_indptr`, `kv_last_page_len` | +| MLA extend | `kv_indptr`, `kv_indices`, `qo_indptr`, `max_q_len`, `max_kv_len` | +| MHA decode | `page_table`, `kv_lens`, `pa_metadata_*` | +| MHA extend | `max_q_len`, `max_kv_len`,以及必要时的 `page_table`, `kv_lens` | + + +## 12. 三个最重要的例子 + +### 12.1 例子一:普通 decode + +假设: + +- `bs = 2` +- `seq_lens = [5, 3]` +- `req_pool_indices = [7, 9]` +- `page_size = 4` + +并且: + +- `req_to_token[7, 0:5] = [100, 101, 102, 103, 120]` +- `req_to_token[9, 0:3] = [200, 201, 220]` + +那么: + +```text +kv_indptr = [0, 5, 8] +kv_indices = [100, 101, 102, 103, 120, 200, 201, 220] +qo_indptr = [0, 1, 2] +kv_last_page_len= [1, 3] +max_q_len = 1 +``` + +含义: + +- 每个 request 只解一个 token +- query 一共 2 个 +- 但每个 query 需要看到自己已有的完整上下文 KV + +### 12.2 例子二:普通 extend / prefill + +假设: + +- request A:prefix 长度 5,本轮 extend 3 个 token,总长 8 +- request B:prefix 长度 7,本轮 extend 2 个 token,总长 9 + +则: + +```text +extend_prefix_lens = [5, 7] +extend_seq_lens = [3, 2] +seq_lens = [8, 9] +qo_indptr = [0, 3, 5] +max_q_len = 3 +max_kv_len = 9 +``` + +理解方式: + +- Q 侧看的是“本轮新增的 3/2 个 token” +- KV 侧看的是“整个请求当前总长度 8/9” + +也就是说: + +- `qo_indptr` 由本轮新增 query 决定 +- `kv_indptr / max_kv_len` 由总上下文长度决定 + +### 12.3 例子三:`TARGET_VERIFY` + +假设: + +- `bs = 2` +- `seq_lens = [5, 3]` +- `draft_token_num = 4` + +那么对每个 request: + +- 本轮要验证 4 个 draft token +- 但每个 query 能看到的 KV 长度不是原始 `seq_lens` +- 而是 `seq_lens + draft_token_num` + +于是: + +```text +qo_indptr = [0, 4, 8] +kv_lens = [9, 7] +kv_indptr = [0, 9, 16] +max_q_len = 4 +``` + +这个例子非常重要,因为它说明: + +- verify 不是普通 decode +- 也不是普通 extend +- 它会同时改变 Q 的分段和 KV 的可见长度 + +这也是为什么 speculative path 不能简单复用普通 extend metadata。 + + +## 13. 一个完整的“逻辑位置 -> 物理 KV”例子 + +假设 request A 当前已经有 5 个 token: + +```text +req_pool_idx = 7 +req_to_token[7, 0:5] = [100, 101, 102, 103, 120] +``` + +这表示: + +- 逻辑 token 0 -> physical slot 100 +- 逻辑 token 1 -> physical slot 101 +- 逻辑 token 2 -> physical slot 102 +- 逻辑 token 3 -> physical slot 103 +- 逻辑 token 4 -> physical slot 120 + +如果本轮又新分配了两个 slot: + +```text +out_cache_loc = [130, 131] +``` + +并把它们写回 request A 的逻辑位置 5、6: + +```text +req_to_token[7, 5] = 130 +req_to_token[7, 6] = 131 +``` + +那么 request A 的完整逻辑到物理映射就变成: + +```text +[100, 101, 102, 103, 120, 130, 131] +``` + +之后 attention metadata 只要知道: + +- `req_pool_idx = 7` +- `kv_len = 7` + +就能通过 `req_to_token` 自动构造出: + +```text +kv_indices = [100, 101, 102, 103, 120, 130, 131] +``` + + +## 14. 可以如何快速判断一个字段该不该看 + +如果你在 debug `sgl_attn_backend.py`,可以用下面这个经验法则: + +- 想知道“这轮每个 request 有多少 query” + - 看 `qo_indptr`, `max_q_len`, `extend_seq_lens` +- 想知道“这轮每个 request 的 KV 能看到多长” + - 看 `kv_indptr`, `kv_last_page_len`, `max_kv_len`, `kv_lens` +- 想知道“这些 KV 实际在 cache 里是哪几个 slot” + - 看 `kv_indices` +- 想知道“request 的逻辑 token 位置和物理 slot 怎么对应” + - 看 `req_to_token_pool.req_to_token` +- 想知道“本轮新 token 会写去哪里” + - 看 `out_cache_loc` + + +## 15. 最后总结 + +如果只记三句话: + +1. `req_to_token` 是 **逻辑 token 位置 -> 物理 KV slot** 的长期映射表。 +2. `kv_indptr + kv_indices` 是把这张长期映射表裁成“本轮 attention 真正要访问的 KV 列表”。 +3. `qo_indptr` 是 query 侧的分段表,告诉 kernel flatten 后哪些 query 属于哪个 request。 + +所以: + +- `scheduler` 决定 batch 语义 +- `req_to_token / out_cache_loc` 决定 cache 物理布局 +- `ForwardMetadata` 把二者翻译成 kernel 真正能消费的索引 + +这就是 `sgl_attn_backend.py` 里这些字段的核心意义。 diff --git a/work_log/MTP/2026-04-08-sglang-kv-cache-storage-guide.md b/work_log/MTP/2026-04-08-sglang-kv-cache-storage-guide.md new file mode 100644 index 000000000..3b7bb3978 --- /dev/null +++ b/work_log/MTP/2026-04-08-sglang-kv-cache-storage-guide.md @@ -0,0 +1,692 @@ +# SGLang KV Cache Storage Guide + +## 文档目的 + +这篇文档专门解释: + +- SGLang 里 KV cache 是怎么存的 +- MHA 和 MLA 的存储结构有什么区别 +- `ReqToTokenPool`、allocator、`TokenToKVPool` 分别负责什么 +- 常见张量 shape 是什么 +- 逻辑 token 是怎么映射到物理 KV slot 的 + +本文聚焦 **SGLang 本身的 KV cache 存储**,不展开 vLLM 或 ATOM native 的实现差异。 + + +## 一句话理解 + +SGLang 的 KV cache 不是“每个 request 一块连续显存”。 + +它更像一个两级系统: + +1. `ReqToTokenPool` + - 维护逻辑映射 + - 回答“这个 request 的第 t 个 token 存在哪个 slot” +2. `TokenToKVPool` + - 维护物理存储 + - 真正把 K/V 或 MLA latent 写进 GPU buffer + +可以先记住这个核心公式: + +```text +req_to_token[req_pool_idx, token_pos] = physical_slot +``` + + +## 1. 先统一几个概念 + +### 1.1 request slot + +SGLang 不直接拿 request id 当数组下标。 + +它会先从 `ReqToTokenPool` 给每个活跃 request 分一个: + +- `req_pool_idx` + +这就是这个 request 在 `req_to_token` 表里的“行号”。 + +### 1.2 token slot + +一个 token slot 表示: + +- 某一层 KV cache 里,一个 token 对应的一行物理存储位置 + +它通常是一个全局整数,例如: + +- `100` +- `101` +- `2025` + +### 1.3 page + +当 `page_size > 1` 时,slot 会按 page 分组。 + +可以理解成: + +```text +一个 page = page_size 个连续 token slot +``` + +于是: + +```text +slot = page_id * page_size + page_offset +``` + +### 1.4 `out_cache_loc` + +`out_cache_loc` 是本轮新分配出来的物理 slot 列表。 + +shape: + +- extend / prefill:`[extend_num_tokens]` +- decode:通常是 `[bs * token_per_req]` + +它表示: + +- 本轮新 token 应该写到 KV cache 的哪些物理位置 + + +## 2. 总体架构 + +可以把 SGLang 的 KV cache 存储看成三层: + +- request 级逻辑层 +- token/page 分配层 +- 物理存储层 + +```mermaid +flowchart LR + A[Request] + B[req_pool_idx] + C[ReqToTokenPool.req_to_token] + D[out_cache_loc / allocator] + E[TokenToKVPool] + F[Physical KV buffers] + + A --> B + B --> C + D --> C + D --> E --> F + C --> F +``` + +更具体一点: + +- `ReqToTokenPool` + - 存“逻辑 token 位置 -> 物理 slot” +- allocator + - 决定这一步还能分到哪些新 slot +- `TokenToKVPool` + - 存每层真实的 K/V 或 MLA latent + + +## 3. 第一层:`ReqToTokenPool` + +文件: + +- `sglang/python/sglang/srt/mem_cache/memory_pool.py` + +核心类: + +- `ReqToTokenPool` + +核心张量: + +- `req_to_token` + +shape: + +- `[req_pool_size, max_context_len]` + +dtype: + +- `int32` + +语义: + +- 第 0 维:request slot,也就是 `req_pool_idx` +- 第 1 维:该 request 的逻辑 token 位置 +- 元素值:这个逻辑 token 对应的物理 KV slot + +也就是: + +```text +req_to_token[req_pool_idx, token_pos] = physical_slot +``` + +### 3.1 例子 + +假设 request A 被分到: + +- `req_pool_idx = 7` + +并且当前已经有 5 个 token: + +```text +req_to_token[7, 0:5] = [100, 101, 102, 103, 120] +``` + +那么它的逻辑到物理映射就是: + +- token 0 -> slot 100 +- token 1 -> slot 101 +- token 2 -> slot 102 +- token 3 -> slot 103 +- token 4 -> slot 120 + +注意: + +- 这里的 slot 不要求连续 +- 因为分页分配、复用、evict 都可能让物理位置不连续 + + +## 4. 第二层:allocator + +allocator 负责: + +- 从可用的 KV 空间中分配新的 slot 或 page +- 返回 `out_cache_loc` +- 再把结果写回 `req_to_token` + +SGLang 里常见有两类 allocator: + +- `TokenToKVPoolAllocator` + - `page_size = 1` + - 更像 token 粒度的平铺分配 +- `PagedTokenToKVPoolAllocator` + - `page_size > 1` + - 更像 page 粒度分配 + +相关文件: + +- `sglang/python/sglang/srt/mem_cache/allocator.py` +- `sglang/python/sglang/srt/mem_cache/common.py` + +### 4.1 `page_size = 1` + +这时 allocator 的视角非常简单: + +- 一个 free slot 就是一个 free token position + +分配出来的 `out_cache_loc` 可以直接看成: + +- 一串 token slot id + +源码里还有一个关键细节: + +- slot `0` 被保留给 padded token / dummy write + +所以真正可分配的 slot 常从 `1` 开始。 + +### 4.2 `page_size > 1` + +这时 allocator 虽然内部按 page 管理, +但对上层仍然返回: + +- token-level 的 `out_cache_loc` + +也就是说,上层最终看到的还是: + +- 这一步每个新 token 具体写到哪个 slot + +只是这些 slot 是由 page allocator 算出来的。 + +### 4.3 extend 时 allocator 做了什么 + +`alloc_for_extend()` 的语义可以概括成: + +1. 先给 request 分配 `req_pool_idx` +2. 再根据 prefix 长度和目标 seq 长度,分配这一步新增 token 的物理 slot +3. 生成 `out_cache_loc` +4. 把这些新 slot 写回 `req_to_token` + +所以: + +- `out_cache_loc` 是“这一步新 token 的写入位置” +- `req_to_token` 是“整个 request 的长期索引表” + +### 4.4 decode 时 allocator 做了什么 + +decode 最常见是每个 request 增加 1 个 token。 + +这时: + +- allocator 为每个 request 分 1 个新 slot +- `out_cache_loc` 的长度通常就是 batch size +- 然后把这个新 slot 写到 `req_to_token[req_pool_idx, 当前 seq_len]` + + +## 5. 第三层:物理 KV 存储 + +这层才是真正的大显存 buffer。 + +SGLang 里和 attention 相关的主要有: + +- `MHATokenToKVPool` +- `MLATokenToKVPool` + +两者最大的区别在于: + +- MHA 存 K 和 V 两份 buffer +- MLA 存一份 packed latent buffer + + +## 6. MHA 的 KV cache 存储 + +文件: + +- `sglang/python/sglang/srt/mem_cache/memory_pool.py` + +核心类: + +- `MHATokenToKVPool` + +### 6.1 核心 buffer + +每层有两份物理 buffer: + +- `k_buffer[layer]` +- `v_buffer[layer]` + +shape: + +- `k_buffer[layer]`: `[(size + page_size), num_kv_heads, head_dim]` +- `v_buffer[layer]`: `[(size + page_size), num_kv_heads, v_head_dim]` + +这里: + +- 第 0 维是物理 slot +- 第 1 维是 KV heads +- 第 2 维是每个 head 的维度 + +`size + page_size` 的原因是: + +- 除了正常容量,还预留了 padding / dummy 写入空间 + +### 6.2 怎么写入 + +写入接口通常是: + +- `set_kv_buffer(layer, loc, cache_k, cache_v, ...)` + +其中: + +- `loc.shape = [num_tokens]` +- `cache_k.shape = [num_tokens, num_kv_heads, head_dim]` +- `cache_v.shape = [num_tokens, num_kv_heads, v_head_dim]` + +语义就是: + +```text +k_buffer[layer][loc[i]] = cache_k[i] +v_buffer[layer][loc[i]] = cache_v[i] +``` + +### 6.3 怎么读 + +读取接口通常是: + +- `get_key_buffer(layer_id)` +- `get_value_buffer(layer_id)` +- `get_kv_buffer(layer_id)` + +attention backend 会根据 `req_to_token` 算出的 slot, +去这些 buffer 里 gather 对应位置。 + + +## 7. MLA 的 KV cache 存储 + +文件: + +- `sglang/python/sglang/srt/mem_cache/memory_pool.py` + +核心类: + +- `MLATokenToKVPool` + +### 7.1 核心 buffer + +MLA 下,每层通常只有一份主 buffer: + +- `kv_buffer[layer]` + +shape: + +- `[(size + page_size), 1, kv_cache_dim]` + +其中: + +```text +kv_cache_dim = kv_lora_rank + qk_rope_head_dim +``` + +这表示: + +- 每个物理 slot 存的是一段 packed latent KV +- 不是标准 MHA 意义上的分离 K / V + +### 7.2 逻辑拆分 + +对于 DeepSeek MLA,通常可以把这段 packed buffer 理解成: + +- 前半段:`kv_a` / latent KV +- 后半段:`k_pe` / rope 相关部分 + +也就是说,一个 slot 里实际上装的是: + +```text +[cache_k_nope | cache_k_rope] +``` + +### 7.3 写入接口 + +MLA 常见有两种写法: + +- `set_kv_buffer(...)` + - 直接把 packed cache 写进去 +- `set_mla_kv_buffer(layer, loc, cache_k_nope, cache_k_rope)` + - 分别传 latent 部分和 rope 部分,由底层 helper 拼到一起 + +典型输入 shape: + +- `loc`: `[num_tokens]` +- `cache_k_nope`: `[num_tokens, 1, kv_lora_rank]` +- `cache_k_rope`: `[num_tokens, 1, qk_rope_head_dim]` + +### 7.4 读取接口 + +MLA 也有专门的读取接口: + +- `get_mla_kv_buffer(layer, loc, dst_dtype)` + +返回: + +- `cache_k_nope`: `[num_tokens, 1, kv_lora_rank]` +- `cache_k_rope`: `[num_tokens, 1, qk_rope_head_dim]` + +所以 MLA 的“读出来再用”其实也是把 packed storage 再拆回两部分。 + + +## 8. MHA 和 MLA shape 对照表 + +| 项目 | MHA | MLA | +|------|-----|-----| +| 主 buffer 数量 | 2 份:K / V | 1 份:packed latent | +| 每层物理 shape | K:`[slots, Hkv, Dk]` V:`[slots, Hkv, Dv]` | `[slots, 1, kv_lora_rank + qk_rope_head_dim]` | +| 第 0 维含义 | 物理 token slot | 物理 token slot | +| 典型写入接口 | `set_kv_buffer()` | `set_mla_kv_buffer()` | +| 典型读取接口 | `get_kv_buffer()` | `get_mla_kv_buffer()` | +| 逻辑视角 | 标准 K/V cache | latent KV + rope 部分 | + + +## 9. `out_cache_loc`、`req_to_token`、buffer 的关系 + +可以把一次写入过程画成: + +```mermaid +flowchart TD + A[allocator 分配新 slot] + B[out_cache_loc] + C[写回 req_to_token] + D[attention forward 产出 K/V 或 MLA latent] + E[set_kv_buffer / set_mla_kv_buffer] + F[物理 KV buffer] + + A --> B + B --> C + B --> E + D --> E --> F +``` + +这里有两个并行动作: + +- `out_cache_loc` 被写回 `req_to_token` +- 同时新算出来的 KV 被写进物理 buffer + +这样下一轮只要知道: + +- `req_pool_idx` +- 当前 `seq_len` + +就能通过 `req_to_token` 找到历史 token 对应的所有物理 slot。 + + +## 10. 例子一:非分页 MHA decode + +假设: + +- `page_size = 1` +- batch 有 2 个 request +- `req_pool_indices = [7, 9]` +- 当前 `seq_lens = [5, 3]` + +已有映射: + +```text +req_to_token[7, 0:5] = [100, 101, 102, 103, 120] +req_to_token[9, 0:3] = [200, 201, 220] +``` + +本轮 decode,每个 request 新增 1 个 token,allocator 返回: + +```text +out_cache_loc = [130, 221] +``` + +然后写回: + +```text +req_to_token[7, 5] = 130 +req_to_token[9, 3] = 221 +``` + +于是下一轮: + +- request A 的完整上下文 slot 是 `[100,101,102,103,120,130]` +- request B 的完整上下文 slot 是 `[200,201,220,221]` + +物理存储上则是: + +```text +k_buffer[layer][130] = new_k_for_A +v_buffer[layer][130] = new_v_for_A + +k_buffer[layer][221] = new_k_for_B +v_buffer[layer][221] = new_v_for_B +``` + + +## 11. 例子二:分页 MHA extend + +假设: + +- `page_size = 4` +- request A 的 prefix 长度 = 5 +- 本轮 extend 后总长度 = 8 + +也就是: + +- prefix token 已经占了 5 个逻辑位置 +- 本轮要再写 3 个 token + +假设它当前最后一个已用 slot 是: + +```text +last_loc = 120 +``` + +而这个 `120` 恰好在某个 page 的中间。 + +那么 allocator 在 `alloc_paged_token_slots_extend()` 里大概会做两件事: + +1. 先尽量把当前未满的最后一个 page 填满 +2. 如果还不够,再分配新 page + +可能得到: + +```text +out_cache_loc = [121, 122, 200] +``` + +这表示: + +- 前两个 token 继续写进原 page 的剩余位置 +- 第三个 token 写进新 page 的第一个 slot + +然后写回: + +```text +req_to_token[7, 5:8] = [121, 122, 200] +``` + +所以分页 allocator 的重点不是“返回 page id”,而是: + +- **最终依然返回 token-level slot ids** + + +## 12. 例子三:MLA 写入和读取 + +假设: + +- `kv_lora_rank = 512` +- `qk_rope_head_dim = 64` + +那么: + +```text +kv_cache_dim = 576 +``` + +对某层来说,MLA 物理 buffer 的 shape 可能是: + +```text +kv_buffer[layer].shape = [num_slots, 1, 576] +``` + +本轮有 2 个新 token: + +```text +loc = [130, 131] +cache_k_nope.shape = [2, 1, 512] +cache_k_rope.shape = [2, 1, 64] +``` + +调用: + +```text +set_mla_kv_buffer(layer, loc, cache_k_nope, cache_k_rope) +``` + +后,可以理解为: + +```text +kv_buffer[layer][130] = concat(cache_k_nope[0], cache_k_rope[0]) +kv_buffer[layer][131] = concat(cache_k_nope[1], cache_k_rope[1]) +``` + +后续 attention 需要读取历史 cache 时,再通过: + +```text +get_mla_kv_buffer(layer, loc=[100,101,130], dst_dtype=bf16) +``` + +拿回: + +- `cache_k_nope`: `[3, 1, 512]` +- `cache_k_rope`: `[3, 1, 64]` + + +## 13. 为什么 SGLang 要搞两层,而不是直接 request -> kv buffer + +因为推理服务不是静态 batch。 + +SGLang 的 request 会不断: + +- 加入 +- 完成 +- 被截断 +- 被 speculative verify / draft_extend 修改 +- 被分页 allocator 扩容 + +如果直接给每个 request 一块连续大 buffer: + +- 复用差 +- 容易碎片化 +- prefix cache / page cache 不好做 + +两层结构的好处是: + +- `ReqToTokenPool` + - 负责逻辑组织 +- `TokenToKVPool` + - 负责物理存储 + +这样: + +- request 的逻辑顺序可以变 +- 物理 slot 可以复用 +- page allocator 可以独立演化 +- MHA / MLA 只需要换底层 KV pool 的 shape,不用重写上层 request 索引系统 + + +## 14. 和 attention metadata 的关系 + +KV cache 存储本身只回答: + +- 数据放在哪里 + +attention metadata 还要回答: + +- 本轮到底读哪些 token +- 这些 token 该怎么分段 +- 对应哪个 request + +所以常见链路是: + +1. `req_to_token` + - 保存 request 的长期逻辑到物理映射 +2. `out_cache_loc` + - 保存本轮新 token 的新物理位置 +3. attention metadata + - 从 `req_to_token` 中抽出本轮真正要访问的那部分 slot + - 形成 `kv_indices` 或 `page_table` + +也就是说: + +- KV cache storage 是“数据库” +- attention metadata 是“查询结果” + + +## 15. 调试时最该先看什么 + +如果你在 debug SGLang KV cache,建议按这个顺序看: + +1. `req_pool_idx` + - 这个 request 映射到哪一行 +2. `req_to_token[row, :seq_len]` + - 当前逻辑 token 对应哪些物理 slot +3. `out_cache_loc` + - 本轮新 token 写到哪里 +4. `k_buffer / v_buffer` 或 `kv_buffer` + - 这些 slot 位置上实际存了什么 shape +5. attention metadata + - 例如 `kv_indices` / `page_table` + - 看本轮真正读的是不是你以为的那些 slot + + +## 16. 最后总结 + +如果只记 6 句话: + +1. `ReqToTokenPool.req_to_token` 是 SGLang KV cache 的逻辑索引总表。 +2. `out_cache_loc` 是本轮新 token 的物理写入位置。 +3. allocator 可能按 token 或 page 分配,但返回给上层的通常仍是 token-level slot。 +4. MHA 物理存储是两份: + - `k_buffer[layer]: [slots, Hkv, Dk]` + - `v_buffer[layer]: [slots, Hkv, Dv]` +5. MLA 物理存储是一份 packed buffer: + - `kv_buffer[layer]: [slots, 1, kv_lora_rank + qk_rope_head_dim]` +6. attention metadata 不是重复存储 KV cache,而是基于 `req_to_token` 再生成“本轮实际访问哪些 KV”的索引视图。 + +这就是 SGLang 里 KV cache 存储的核心结构。 diff --git a/work_log/MTP/2026-04-08-sglang-speculative-decoding-architecture.md b/work_log/MTP/2026-04-08-sglang-speculative-decoding-architecture.md new file mode 100644 index 000000000..26db38442 --- /dev/null +++ b/work_log/MTP/2026-04-08-sglang-speculative-decoding-architecture.md @@ -0,0 +1,910 @@ +# 2026-04-08 SGLang Speculative Decoding 架构笔记 + +## 文档目的 + +这份文档用于从 **SGLang 整体架构** 的角度梳理 speculative decoding +(推测解码)的实现方式,重点回答下面几个问题: + +- SGLang 在启动阶段如何决定是否进入 speculative decoding +- target model 和 draft model 是如何被构造和组织的 +- scheduler、worker、batch、attention backend 分别承担什么职责 +- EAGLE / EAGLE3 / NEXTN / STANDALONE / NGRAM 在 SGLang 里是怎样映射的 +- v1 与 v2(overlap/spec v2)在控制流和数据结构上有何差异 +- 遇到问题时,应该优先看哪些代码位置 + +本文会尽量从“系统设计”和“源码位置”两个维度同时展开,方便后续复盘。 + + +## 一句话总结 + +从架构上看,SGLang 的 speculative decoding 可以概括成: + +- **配置层** 先在 `ServerArgs` 中解析算法与 draft 参数 +- **模型配置层** 决定 draft model 应该映射成哪一个架构名 +- **调度层** 通过 `Scheduler` 把执行入口从普通 `TpModelWorker` 切到 + speculative orchestrator +- **worker 层** 维护 target worker 与 draft worker 的协作关系 +- **batch 层** 用 `ScheduleBatch -> ModelWorkerBatch -> ForwardBatch` 三层结构 + 将调度语义转为 GPU 执行语义 +- **attention/backend 层** 再按 `ForwardMode` 和 `SpecInput` 区分 + `decode / target_verify / draft_extend` + +也就是说,speculative decoding 在 SGLang 里不是“加一个 draft model”这么简单, +而是一整套跨: + +- 配置 +- 调度 +- 模型加载 +- batch 编排 +- attention metadata + +的系统设计。 + + +## 1. 启动入口:ServerArgs 如何决定 speculative decoding + +### 1.1 关键配置项 + +核心入口文件: + +- `sglang/python/sglang/srt/server_args.py` + +关键字段位置: + +- `speculative_algorithm`:约 `480` +- `speculative_draft_model_path`:约 `481` +- `speculative_num_steps`:约 `484` +- `speculative_num_draft_tokens`:约 `486` + +这些字段决定: + +- 用哪种 speculative 算法 +- draft model 从哪里加载 +- 每轮 draft 几步 +- 每轮最多提议多少 draft token + + +### 1.2 `NEXTN` 在 SGLang 里的真实含义 + +很多人第一次看会以为 `NEXTN` 是一条完全独立的 speculative runtime。 +实际上不是。 + +在: + +- `sglang/python/sglang/srt/server_args.py` +- `_handle_speculative_decoding()` 逻辑中 + +有一个关键规范化: + +- `NEXTN -> EAGLE` + +对应代码位置: + +- `server_args.py` 约 `2680-2681` + +这意味着: + +- 用户在命令行里写 `--speculative-algorithm NEXTN` +- 进入运行时后,SGLang 会把它归并到 `EAGLE` 这套 speculative worker 流程里 + +也就是说: + +- `NEXTN` 更像是“draft model 形态 / 语义” +- `EAGLE` 更像是“runtime orchestration 机制” + + +### 1.3 spec v2 与 overlap scheduler + +仍然是在: + +- `server_args.py` +- `_handle_speculative_decoding()` + +关键逻辑位置: + +- `2696-2716` + +SGLang 会做一件重要的系统级决策: + +- 如果 speculative 算法属于 `EAGLE / EAGLE3 / STANDALONE` +- 且环境变量 `SGLANG_ENABLE_SPEC_V2=True` +- 则开启 overlap schedule(即 spec v2) + +否则: + +- 会退回到不带 overlap 的传统路径(可以理解为 spec v1) + +同时还有一些额外约束: + +- spec v2 目前只支持 `topk = 1` +- 使用 speculative 时会关闭 mixed chunked prefill + + +### 1.4 DeepSeek / MTP 与 `speculative_draft_model_path` + +同一段逻辑里还有一个对 DeepSeek 很关键的行为: + +- 对 `DeepseekV3ForCausalLM`、`DeepseekV32ForCausalLM`、`GlmMoeDsaForCausalLM` + 等架构 +- 如果没有显式传 `speculative_draft_model_path` +- 会自动把它设成主模型路径 + +关键位置: + +- `server_args.py` 约 `2725-2748` + +这就是为什么日志里会有类似: + +- `DeepSeek MTP does not require setting speculative_draft_model_path.` + +的提示。 + +这说明 SGLang 把 DeepSeek MTP / NextN 看成是某种“和 target 模型强绑定”的 +draft 形态,而不是完全独立的小模型。 + + +## 2. 算法层:SpeculativeAlgorithm 与 worker 工厂 + +核心文件: + +- `sglang/python/sglang/srt/speculative/spec_info.py` + +### 2.1 算法枚举 + +关键枚举: + +- `SpeculativeAlgorithm` + +包含: + +- `EAGLE` +- `EAGLE3` +- `STANDALONE` +- `NGRAM` +- `NONE` + +关键位置: + +- `spec_info.py` 约 `15-23` + + +### 2.2 worker 工厂 + +最关键的方法: + +- `SpeculativeAlgorithm.create_worker()` + +关键位置: + +- `spec_info.py` 约 `52-105` + +这个函数负责把: + +- 算法类型 +- overlap 是否开启 +- multi-layer eagle 是否开启 + +映射成具体 worker 类。 + +典型映射关系: + +- `EAGLE + overlap` -> `EAGLEWorkerV2` +- `EAGLE + no overlap` -> `EAGLEWorker` +- `STANDALONE + overlap` -> `StandaloneWorkerV2` +- `STANDALONE + no overlap` -> `StandaloneWorker` +- `NGRAM` -> `NGRAMWorker` + + +### 2.3 什么叫 “supports_spec_v2” + +还有个很重要的方法: + +- `supports_spec_v2()` + +关键位置: + +- `spec_info.py` 约 `49-50` + +含义是: + +- 当前算法是否支持 overlap/spec v2 抽象 + +目前只有: + +- `EAGLE` +- `STANDALONE` + +对应为真。 + + +## 3. 调度层:Scheduler 如何把普通模型调度切成 speculative 调度 + +核心文件: + +- `sglang/python/sglang/srt/managers/scheduler.py` + +### 3.1 初始化顺序 + +关键位置: + +- `maybe_init_draft_worker()`:约 `527-554` +- `init_model_worker()`:约 `556-564` + +逻辑顺序是: + +1. 先建 `tp_worker` +2. 如果 speculative 开启,再建 `draft_worker` +3. 决定 `self.model_worker` 指向谁 + +代码语义: + +- 没开 speculative: + - `self.model_worker = self.tp_worker` +- 开了 speculative: + - `self.model_worker = self.draft_worker` + + +### 3.2 为什么 `self.model_worker = self.draft_worker` + +这里名字非常容易误导。 + +`scheduler.draft_worker` 并不一定是一个“纯 draft model worker”,它更像是: + +- speculative orchestrator + +例如: + +- `EAGLEWorker` +- `EAGLEWorkerV2` + +也就是说: + +- scheduler 并不是“把 target worker 替换掉了” +- 而是把执行入口切到了一个能同时协调 target + draft 的总控 worker + + +### 3.3 `run_batch()` 的差异 + +关键位置: + +- `scheduler.py` 约 `2360-2426` + +这里能看出 v1 与 v2 在 batch 抽象上的差异: + +- 开 overlap/spec v2 时: + - `worker_batch_or_batch = batch.get_model_worker_batch()` + - 下游主要处理 `ModelWorkerBatch` +- 非 overlap 的传统 speculative v1: + - 会直接把 `ScheduleBatch` 传给 `model_worker.forward_batch_generation()` + +这也是为什么你有时候会看到: + +- 有的 speculative worker 收的是 `ScheduleBatch` +- 有的 speculative worker 收的是 `ModelWorkerBatch` + +这不是 bug,而是新老抽象并存。 + + +## 4. Worker 层:target worker、draft worker 与 orchestrator 的关系 + +### 4.1 普通 target worker + +核心文件: + +- `sglang/python/sglang/srt/managers/tp_worker.py` + +`TpModelWorker` 是普通模型执行单元,负责: + +- 初始化 `ModelConfig` +- 初始化 `ModelRunner` +- 提供 `forward_batch_generation()` + +关键位置: + +- `_init_model_config()`:约 `320-336` +- `_init_model_runner()`:约 `338-358` +- `forward_batch_generation()`:约 `442+` + + +### 4.2 target 和 draft 的分流在哪里发生 + +在 `TpModelWorker._init_model_config()` 中: + +- 如果 `is_draft_worker=False`,用主模型路径 +- 如果 `is_draft_worker=True`,用 `speculative_draft_model_path` + +关键位置: + +- `tp_worker.py` 约 `323-336` + +这就是 target 和 draft 最底层模型配置分流的地方。 + + +### 4.3 `EAGLEWorker`(v1) + +核心文件: + +- `sglang/python/sglang/srt/speculative/eagle_worker.py` + +`EAGLEWorker` 的特点: + +- 自己继承自 `TpModelWorker` +- 运行时同时持有: + - target worker + - 自己这套 draft model runner + +其 `forward_batch_generation()` 的大致逻辑是: + +- 如果是 extend: + - 先 `forward_target_extend` + - 再 `forward_draft_extend` +- 如果是 decode: + - 先 `draft()` + - 再 `verify()` + - 再 `forward_draft_extend_after_decode()` + +关键位置: + +- `eagle_worker.py` 约 `278-337` + + +### 4.4 `EAGLEWorkerV2`(spec v2 / overlap) + +核心文件: + +- `sglang/python/sglang/srt/speculative/eagle_worker_v2.py` + +v2 与 v1 最大的结构差异是: + +- 外层 `EAGLEWorkerV2` 是 orchestrator +- 内层还有一个 `EagleDraftWorker` +- `EagleDraftWorker` 再内嵌一个真正的 draft `TpModelWorker` + +关键类: + +- `EagleDraftWorker`:约 `82` +- `EAGLEWorkerV2`:约 `607` + +这层设计的意义是: + +- 把 draft 逻辑进一步模块化 +- 更方便做 overlap 和独立的 draft graph / backend 管理 + + +### 4.5 `StandaloneWorkerV2` + +核心文件: + +- `sglang/python/sglang/srt/speculative/standalone_worker_v2.py` + +它和 `EAGLEWorkerV2` 的主要区别不是调度框架,而是: + +- draft model 不再共享 target 的 embedding / lm_head + +在源码里可以看到: + +- `StandaloneDraftWorker.init_lm_head()` 明确覆写为空实现 + +也就是: + +- standalone draft 用自己的一套 embedding/head +- 不走与 target 的共享逻辑 + + +## 5. 模型配置与 draft 架构改写 + +核心文件: + +- `sglang/python/sglang/srt/configs/model_config.py` + +### 5.1 `ModelConfig.from_server_args()` + +这是 target / draft `ModelConfig` 的统一入口。 + +关键位置: + +- `from_server_args()`:约 `238+` + + +### 5.2 `_config_draft_model()` + +最关键的方法: + +- `_config_draft_model()` + +关键位置: + +- `model_config.py` 约 `277-340` + +对于 DeepSeek: + +- 若 `is_draft_model=True` +- 且原始架构是 `DeepseekV3ForCausalLM` + +就会改写为: + +- `DeepseekV3ForCausalLMNextN` + +这是 draft 侧为什么会变成 NextN 壳子的核心原因。 + + +### 5.3 这和 ATOM plugin 的关系 + +这也是当前 `ATOM plugin` 只接管 target、没接管 draft 的根源: + +- ATOM external model package 只导出了 `DeepseekV3ForCausalLM` +- 并没有导出 `DeepseekV3ForCausalLMNextN` + +所以最终结果是: + +- target `DeepseekV3ForCausalLM` 被 external package 覆盖 +- draft `DeepseekV3ForCausalLMNextN` 仍走 upstream SGLang native + + +## 6. 模型实例化链路 + +### 6.1 `ModelRunner.load_model()` + +核心文件: + +- `sglang/python/sglang/srt/model_executor/model_runner.py` + +关键位置: + +- `load_model()`:约 `901-991` + +这里完成: + +- 构造 `LoadConfig` +- 选择 model loader +- 调 `loader.load_model(...)` + + +### 6.2 `ModelRunner._get_attention_backend()` + +关键位置: + +- `model_runner.py` 约 `1736-1746` + +这里会根据: + +- 是否是 draft worker +- 是否设置了 `speculative_draft_attention_backend` + +来决定 draft 用哪种 attention backend。 + +这是 speculative 与 attention backend 结合的一个重要入口。 + + +### 6.3 `_initialize_model()` + +核心文件: + +- `sglang/python/sglang/srt/model_loader/loader.py` + +关键位置: + +- `_initialize_model()`:约 `257-277` + +这是底层真正 `return model_class(**kwargs)` 的地方。 + +也就是说: + +- target 和 draft 在上层是两个不同 worker / model config +- 但底层最终都汇合到同一个模型实例化函数 + + +### 6.4 `get_model_architecture()` + +核心文件: + +- `sglang/python/sglang/srt/model_loader/utils.py` + +关键位置: + +- `get_model_architecture()`:约 `89-119` + +这个函数负责: + +- 看 `hf_config.architectures` +- 查 `ModelRegistry` +- 最终选出要实例化的 model class + +如果 external package 覆盖了同名架构,就会优先拿 external package 的类。 + + +## 7. 三层 batch 数据结构 + +核心文件: + +- `sglang/python/sglang/srt/model_executor/forward_batch_info.py` +- `sglang/python/sglang/srt/managers/schedule_batch.py` + +### 7.1 三层抽象 + +`forward_batch_info.py` 文件开头就写得很清楚: + +- `ScheduleBatch -> ModelWorkerBatch -> ForwardBatch` + +含义: + +- `ScheduleBatch` + - scheduler 侧高层调度状态 + - CPU 语义更强 +- `ModelWorkerBatch` + - 给 worker 的中间态 +- `ForwardBatch` + - 最接近 kernel / backend 执行的低层态 + + +### 7.2 `ForwardMode` + +关键位置: + +- `forward_batch_info.py` 约 `74-179` + +推测相关最重要的几个 mode: + +- `TARGET_VERIFY` +- `DRAFT_EXTEND` +- `DRAFT_EXTEND_V2` +- `DECODE` +- `EXTEND` + +这里有个容易踩坑的点: + +- `TARGET_VERIFY` 在 `is_extend()` 里返回真 + +所以如果 backend 只按 “decode vs extend” 粗暴分流,很容易把 verify 错当普通 extend。 + + +### 7.3 `ScheduleBatch.get_model_worker_batch()` + +核心文件: + +- `sglang/python/sglang/srt/managers/schedule_batch.py` + +关键位置: + +- `get_model_worker_batch()`:约 `2175-2228` + +这一步负责把 scheduler 层状态打包成 `ModelWorkerBatch`。 + +关键理解: + +- 对 `decode_or_idle()`,`extend_seq_lens` 会被设成 `None` +- 对其他 extend 类路径,`extend_seq_lens` 来自 `self.extend_lens` + +这也是后面 verify 路径里经常出现 `extend_seq_lens=None` 的背景。 + + +## 8. speculative 的核心数据结构:SpecInput + +核心文件: + +- `sglang/python/sglang/srt/speculative/spec_info.py` + +关键抽象: + +- `SpecInput` +- `SpecInputType` + +类型包括: + +- `EAGLE_DRAFT` +- `EAGLE_VERIFY` +- `NGRAM_VERIFY` + +也就是说,speculative 不只是“多传几个 tensor”,而是有一套专门的数据结构协议。 + + +### 8.1 DeepSeek / EAGLE 相关具体实现 + +主要文件: + +- `sglang/python/sglang/srt/speculative/eagle_info.py` +- `sglang/python/sglang/srt/speculative/eagle_info_v2.py` + +这些文件负责: + +- draft 输入构造 +- verify 输入构造 +- draft token / hidden state / custom mask / positions 等 speculative 元数据 + + +## 9. EAGLE v1 的主执行流程 + +核心文件: + +- `sglang/python/sglang/srt/speculative/eagle_worker.py` + +### 9.1 extend / prefill 阶段 + +关键位置: + +- `forward_batch_generation()`:约 `278-309` + +流程: + +1. target 先跑 extend / prefill +2. target 产出 hidden state +3. draft 用 target hidden state 再做 draft extend + + +### 9.2 decode 阶段 + +关键位置: + +- `forward_batch_generation()`:约 `310-337` + +流程: + +1. draft 先 propose +2. target 再 verify +3. draft 根据 verify 结果再 extend,为下一轮准备 + +这是一个典型的: + +- `draft -> target verify -> draft extend` + +链式协作过程。 + + +### 9.3 why target and draft share embed/head + +关键位置: + +- `eagle_worker.py` 约 `157-183` + +这里会显式调用: + +- `target_worker.model_runner.model.get_embed_and_head()` +- `draft_model_runner.model.set_embed_and_head(...)` + +说明: + +- upstream 的 EAGLE/NextN draft 设计默认依赖 target 的 embedding 和 lm_head + + +## 10. EAGLE v2 的主执行流程 + +核心文件: + +- `sglang/python/sglang/srt/speculative/eagle_worker_v2.py` + +### 10.1 prefill / extend + +关键位置: + +- `forward_batch_generation()`:约 `673-697` + +流程: + +1. target prefill +2. draft prefill +3. 返回 `next_draft_input` + + +### 10.2 decode / verify + +关键位置: + +- `forward_batch_generation()`:约 `698-722` +- `verify()`:约 `724-780` + +流程: + +1. `draft_worker.draft()` 生成 `EagleVerifyInput` +2. `verify()` 内部构造 verify forward batch +3. target 执行 verify 前向 +4. draft 再做 `_draft_extend_for_decode()` + + +### 10.3 spec v2 的一个核心特征 + +它不再直接围绕 `ScheduleBatch` 做所有 speculative 逻辑,而是更偏向: + +- `ModelWorkerBatch` +- `next_draft_input` +- `future_indices` +- overlap plan stream + +这也是它和 v1 最大的结构差异。 + + +## 11. attention backend 如何感知 speculative + +最典型的文件: + +- `sglang/python/sglang/srt/layers/attention/aiter_backend.py` + +### 11.1 `init_forward_metadata()` + +关键位置: + +- `aiter_backend.py` 约 `435+` + +这段代码是理解 speculative attention 路径最关键的入口之一。 + +它不是简单区分: + +- decode +- extend + +而是细分: + +- `decode_or_idle` +- `draft_extend` +- `target_verify` +- 普通 extend + + +### 11.2 `draft_extend` + +关键位置: + +- `aiter_backend.py` 约 `526-606` + +特点: + +- 通过 `spec_info.generate_attn_arg_prefill(...)` + 来生成 draft extend 所需的 attention 参数 +- 对 MLA 与非 MLA 路径分别处理 + + +### 11.3 `target_verify` + +关键位置: + +- `aiter_backend.py` 约 `607+` + +特点: + +- 不依赖普通 extend 的 `extend_seq_lens` +- 直接根据: + - `spec_info.draft_token_num` + - `forward_batch.seq_lens` + +构造 verify 所需的: + +- `qo_indptr` +- `kv_indptr` +- `kv_indices` + +这是后续排查 `ATOM plugin verify` 问题时最值得对照的一段。 + + +## 12. v1 与 v2 的差异总结 + +### 12.1 v1 + +特点: + +- 更偏 `ScheduleBatch` +- speculative worker 逻辑更集中在一个大类里 +- 以串行 orchestrate 为主 + + +### 12.2 v2 + +特点: + +- 依赖 overlap scheduler +- 更偏 `ModelWorkerBatch` +- 引入 `next_draft_input` +- 更明显地区分 draft worker 与 orchestrator worker +- 更容易做 plan stream / overlap + + +### 12.3 对调试的实际影响 + +如果你调试 speculative 问题,一定要先分清: + +- 当前是 v1 还是 v2 +- 当前 `model_worker.forward_batch_generation(...)` + 收到的是 `ScheduleBatch` 还是 `ModelWorkerBatch` + +否则很容易误判字段来源和生命周期。 + + +## 13. 一份推荐阅读顺序 + +如果之后需要重新从头理解 SGLang speculative decoding,建议按下面顺序读: + +1. `server_args.py` + - 看 speculative 参数、NEXTN 规范化、spec v2 开关 +2. `spec_info.py` + - 看算法枚举和 worker 工厂 +3. `scheduler.py` + - 看 `maybe_init_draft_worker()` / `init_model_worker()` / `run_batch()` +4. `tp_worker.py` + - 看 target 与 draft `ModelConfig` 的分流 +5. `model_config.py` + - 看 `_config_draft_model()` +6. `deepseek_nextn.py` + - 看 `DeepseekV3ForCausalLMNextN` 到底长什么样 +7. `eagle_worker.py` + - 看传统 speculative v1 主流程 +8. `eagle_worker_v2.py` + - 看 overlap/spec v2 主流程 +9. `schedule_batch.py` + - 看 `ScheduleBatch -> ModelWorkerBatch` +10. `forward_batch_info.py` + - 看 `ForwardMode` 和 `ForwardBatch` +11. `aiter_backend.py` + - 看 speculative attention metadata 怎么初始化 + + +## 14. 对 ATOM plugin 调试的启发 + +这份背景知识对 `ATOM + SGLang plugin` 的调试最直接的启发有三点: + +### 启发 1 + +不能只盯 model class,还要盯: + +- `ServerArgs` +- `ModelConfig` +- `Scheduler` +- `TpModelWorker` + +因为 draft/target 的分流在这些层已经决定了。 + + +### 启发 2 + +如果 plugin 只覆盖了: + +- `DeepseekV3ForCausalLM` + +但没有覆盖: + +- `DeepseekV3ForCausalLMNextN` + +那么最终一定会形成: + +- target 走 plugin +- draft 走 upstream + +的混合运行形态。 + + +### 启发 3 + +如果 attention backend 只按: + +- decode +- extend + +粗暴分流,而没有补: + +- `target_verify` +- `draft_extend` + +这类 speculative 专有 metadata 路径, +那么在 speculative 模式下一定迟早会在 verify / draft_extend 里出错。 + + +## 15. 最终总结 + +SGLang 的 speculative decoding 并不是一个局部 feature,而是一套完整的运行时体系: + +- 在配置层决定算法和 draft model 语义 +- 在 model config 层改写 draft 架构 +- 在 scheduler 层切换到 speculative orchestrator +- 在 worker 层维护 target / draft 协作 +- 在 batch 层用三层结构管理状态转换 +- 在 attention backend 层按 `ForwardMode` 和 `SpecInput` + 细化 metadata 初始化 + +如果后续要把 `ATOM MTP` 真正接到 plugin 路径上,最重要的不是先改单个 kernel, +而是先把这张图看清楚: + +- 谁负责 target +- 谁负责 draft +- 谁负责调度 +- 哪些数据结构在层层转换 +- speculative 特有的 `target_verify` / `draft_extend` + 在 attention/backend 层是如何被建模的 + +只有在这个架构认知稳定之后,后面的接入和调试才会高效。 diff --git a/work_log/MTP/2026-04-08-vllm-continuous-batching.md b/work_log/MTP/2026-04-08-vllm-continuous-batching.md new file mode 100644 index 000000000..2e643b89e --- /dev/null +++ b/work_log/MTP/2026-04-08-vllm-continuous-batching.md @@ -0,0 +1,1214 @@ +# 2026-04-08 vLLM Continuous Batching 原理与源码位置笔记 + +## 文档目的 + +这份文档用于系统梳理 `vLLM` 中与 `continuous batching` 相关的核心机制, +重点回答下面几个问题: + +- `continuous batching` 到底是什么,和静态 batch 有什么本质差异 +- `scheduler` 是如何把多个 request 组装成一次 step 的执行 batch 的 +- 一次 step 之后,模型输出是如何再放回各个 request 的 +- 为什么 vLLM 的 batch 不是传统训练里那种规则的 `B x L` +- `prefill / decode / chunked prefill / speculative decode` 在这个框架下是怎样统一的 +- 如果要顺着源码看,应该优先看哪些代码位置 + +本文尽量从三个维度同时展开: + +- 系统设计 +- 张量 shape +- 源码入口 + +方便后续复盘、调试和与其他推理框架做对比。 + + +## 一句话总结 + +`vLLM continuous batching` 的核心不是“把很多请求 pad 成一个固定 `B x L` 大矩阵”, +而是: + +- 每个 step 动态决定“每个 request 这一步前进多少 token” +- 把这些 token 展平成一个按 token 计数的 flat batch +- 用 `query_start_loc / seq_lens / block_tables / slot_mapping` 等元数据告诉 GPU: + - 每个 token 属于哪个 request + - 它在该 request 中的位置是多少 + - 它应该读写哪一段 KV cache +- step 结束后,再通过 `req_id_to_index` 把输出准确拆回每个 request + +可以把它概括成: + +```text +requests + -> scheduler 决定本步每个 request 的 n_i + -> 组装成 flat token batch, 总 token 数 T = sum(n_i) + -> GPU forward / sample + -> 用 req_id_to_index 把输出拆回 request +``` + + +## 版本说明 + +本文主要覆盖两条线: + +- `v1 / current main` 风格实现 +- `v0` 旧版 `LLMEngine` / `SequenceGroup` 风格实现 + +需要注意: + +- `v1` 是当前更值得优先看的主线 +- `v0` 仍然有很强的参考价值,因为很多文章、issue、历史讨论都仍然沿用 + `SequenceGroup`、`SchedulerOutputs` 这套命名 +- 本文提到的源码位置与行号,基于 `2026-04-08` 抓取的 upstream 快照, + 后续可能会轻微漂移 + + +## 1. 为什么 continuous batching 不是静态 batching + +### 1.1 静态 batching 的直觉 + +训练或普通离线推理中,大家更熟悉的是: + +- 给定一批序列 +- pad 到同一个长度 +- 形成一个规则张量 + +例如: + +```text +input_ids.shape = [B, L] +attention_mask.shape = [B, L] +``` + +这种做法的假设是: + +- 这一批样本一起开始 +- 一起执行 +- 一起结束 + + +### 1.2 在线 serving 的问题 + +在线服务时,请求并不是同时到达,也不会同时结束。 + +典型情况是: + +- 某些 request 还在做长 prompt 的 prefill +- 某些 request 已经进入 decode,每步只需要前进 1 个 token +- 某些 request 刚结束 +- 某些新 request 又刚进入系统 + +如果还坚持用静态 batch,就会遇到: + +- 等待新 request 凑满 batch,TTFT 变差 +- 某个 request 提前结束后,batch 中留下空洞 +- prompt 很长的 request 会拖慢所有其他 request + + +### 1.3 continuous batching 的本质 + +所以 vLLM 的选择不是“固定一批 request 一起跑到结束”,而是: + +- 每个调度 step 都重新看当前系统中的 request +- 决定本步哪些 request 参与 +- 决定每个 request 本步前进多少 token +- step 结束后,再立刻重组下一轮 batch + +因此,batch 是“连续流动”的。 + +这也是 `continuous batching` 这个名字的真正含义。 + + +## 2. vLLM 视角下一个 request 的核心状态 + +在 vLLM 中,理解 request 的关键不是先区分“prefill 还是 decode”, +而是先看下面几个状态量。 + +### 2.1 最重要的两个量 + +- `all_token_ids` + - 当前这个 request 已知的完整 token 序列 + - 包括 prompt token,也包括已经生成出来但可能尚未被下一轮 compute 的 token +- `num_computed_tokens` + - `all_token_ids` 中已经真正做过 forward、对应 KV 已经落到 cache 的前缀长度 + +于是: + +- 如果 `num_computed_tokens = 0`,说明 prompt 还没 prefill +- 如果 `num_computed_tokens < len(all_token_ids)`,说明还有 backlog 没算 +- decode 阶段常见情况是: + - 上一轮 sample 出了 1 个新 token + - 下一轮需要把这 1 个 token 真正送进模型计算 + - 所以通常 backlog 是 1 + + +### 2.2 统一 prefill / decode 的关键观察 + +从 scheduler 角度,并不存在一个特别刚性的: + +- “prefill phase” +- “decode phase” + +更接近的真实逻辑是: + +- 对每个 request,看它还有多少 token 没被 compute +- 本轮决定从这些 backlog 中取多少 token 来执行 + +因此: + +- 新 request 的 backlog 通常很大,对应 prefill +- 老 request 的 backlog 常常只有 1,对应 decode +- chunked prefill 只是“长 request 的 backlog 一次不要全吃完” + +也就是说,`prefill / decode` 更像是同一调度框架下的两种常见形态。 + + +### 2.3 KV cache 也是 request 状态的一部分 + +除了 token 序列本身,每个 request 还绑定: + +- KV cache block +- block table +- 对应的 slot mapping + +这决定了: + +- decode 时虽然本轮可能只新输入 1 个 token +- 但模型仍然能通过 KV cache 读取全部历史上下文 + +所以一个 request 的有效状态并不只是 token ids,而是: + +```text +request state + = token sequence + + num_computed_tokens + + sampling / stop state + + KV cache mapping + + (可选)LoRA / multimodal / structured output 状态 +``` + + +## 3. 一个 batch 的“实质性内容”到底是什么 + +这是最容易误解的地方。 + +### 3.1 从 tokenizer 语义看 + +一个 token 最原始确实就是一个 vocab id,也就是一个整数。 + +例如: + +```text +"hello" -> 15496 +``` + +因此在输入层面,`input_ids` 的每个元素确实就是“一个数字”。 + + +### 3.2 从 GPU 执行看 + +但真正送进 GPU 跑一次 step,远远不只有 `input_ids`。 + +至少还需要: + +- `input_ids` +- `positions` +- `query_start_loc` +- `seq_lens` +- `block_tables` +- `slot_mapping` +- `logits_indices` + +特殊情况下还会有: + +- `inputs_embeds` +- multimodal encoder 相关输入 +- LoRA metadata +- speculative decode 的 draft token 相关索引 +- structured output grammar bitmask + +所以如果问: + +> 一个 batch 的实质性 token,是不是仅仅是简单的 input_id? + +答案是: + +- 对“token 身份”来说,最原始确实是 `input_id` +- 对“一次 forward 的完整执行语义”来说,绝对不够 + +因为模型还必须知道: + +- 这个 token 属于哪个 request +- 它在该 request 里的绝对位置是多少 +- 它应该从 KV cache 的哪一段读取历史上下文 + + +### 3.3 这个 `input_id` 在 GPU 上吗 + +在真正执行时,是的。 + +更准确地说: + +- request 和 scheduler 主要在 CPU 侧维护高层状态 +- 但进入 worker / model runner 后,本步需要用到的 + `input_ids / positions / query_start_loc / block_tables / slot_mapping` + 会被放入 GPU buffer +- 然后做 embedding lookup,进入 transformer forward + +所以: + +- 逻辑上的 token id 一开始常出现在 CPU 侧 +- 本步执行用到的 `input_ids` 会进入 GPU +- 进入模型后,它很快会被 embedding 成一个向量 + +例如: + +```text +token_id: scalar + -> embedding lookup + -> hidden vector: [hidden_size] +``` + +如果本轮总共执行 `T` 个 token,那么 embedding 后大致就是: + +```text +[T, hidden_size] +``` + + +## 4. scheduler 到底在做什么 + +### 4.1 核心目标 + +scheduler 的工作不是“把所有 request pad 成一个矩阵”,而是: + +- 从 `waiting / running` 队列里挑 request +- 决定每个 request 本步前进多少 token +- 保证不超出资源预算 +- 必要时做 preemption +- 产出 worker 能执行的调度结果 + + +### 4.2 主要约束 + +在 `v1` 中,最重要的两个约束是: + +- `max_num_seqs` + - 本步最多同时挂多少个 request +- `max_num_batched_tokens` 或 `max_num_scheduled_tokens` + - 本步总共最多前进多少 token + +此外还会考虑: + +- model max length +- encoder 计算预算(多模态) +- LoRA 同批数量限制 +- KV cache block 是否足够 +- prefix cache / remote KV / async loading 状态 + + +### 4.3 调度的高层流程 + +`v1` 的 `Scheduler.schedule()` 大致可以概括成: + +1. 先尝试调度 `running` request +2. 再尝试从 `waiting` 里吸入新 request +3. 对每个 request 决定本步的 `n_i` +4. 维护 `token_budget` +5. 如果 block 不够或约束冲突,必要时 preempt 某些 request +6. 输出 `SchedulerOutput` + +一个很关键的设计点是: + +> scheduler 关心的是 “本步每个 request 前进多少 token” + +而不是: + +> “这个 request 属于 prefill 还是 decode 类别” + + +### 4.4 chunked prefill 是怎样融进去的 + +长 prompt 的 request,如果一次全吃完会把 token budget 吃光, +拖累其他 request。 + +所以 vLLM 会在需要时把它拆成多步: + +- 本轮只 prefill prompt 的一部分 +- 剩下的下轮再继续 + +因此一个 request 可以出现: + +- 还在 prefill chunk 中 +- 但同时其他 request 已经在 decode + +这正是 continuous batching 最典型的混合场景。 + + +## 5. scheduler 输出的关键数据结构 + +在 `v1` 中,scheduler 输出的核心抽象是: + +- `NewRequestData` +- `CachedRequestData` +- `SchedulerOutput` + +可以把这三者理解成: + +- `NewRequestData` + - 首次进入 worker 的 request,要发送完整初始化数据 +- `CachedRequestData` + - worker 已经缓存过的 request,只发送增量信息 +- `SchedulerOutput` + - 这一步所有调度决策的总封装 + + +### 5.1 `NewRequestData` + +它通常包含: + +- `req_id` +- `prompt_token_ids` +- `sampling_params` +- `pooling_params` +- `block_ids` +- `num_computed_tokens` +- `lora_request` +- `prefill_token_ids`(v2 model runner 相关) + +也就是说,新 request 第一次进入 worker 时,需要把足够多的静态信息发过去, +让 worker 端建立自己的 request cache。 + + +### 5.2 `CachedRequestData` + +这个结构是 continuous batching 很重要的一环,因为它体现了: + +> worker 对 request 状态是“长期缓存”的,而不是每 step 重建。 + +典型字段有: + +- `req_ids` +- `resumed_req_ids` +- `new_token_ids` +- `all_token_ids` +- `new_block_ids` +- `num_computed_tokens` +- `num_output_tokens` + +其中最关键的思想是: + +- 对已经在 worker 里的 request,不重复发送整条 request +- 只发送变化的部分 + +这能显著减少调度端和 worker 之间的通信成本。 + + +### 5.3 `SchedulerOutput` + +最重要的字段有: + +- `scheduled_new_reqs` +- `scheduled_cached_reqs` +- `num_scheduled_tokens: dict[req_id, int]` +- `total_num_scheduled_tokens` +- `scheduled_spec_decode_tokens` +- `scheduled_encoder_inputs` +- `finished_req_ids` + +其中: + +- `num_scheduled_tokens` 是整轮 step 的核心 +- 它表达的是: + - 这个 request 这一步要前进几个 token + +如果把本轮调度到了 `B` 个 request,则: + +```text +num_scheduled_tokens: {req_id_1: n_1, ..., req_id_B: n_B} +T = n_1 + ... + n_B +``` + +这里: + +- `B` 是 request 数 +- `T` 是 token 数 + +vLLM 后续执行更偏向围绕 `T` 展开,而不是围绕规则的 `B x L`。 + + +## 6. 真正执行时,batch 的 shape 长什么样 + +这是理解 vLLM 最关键的一节。 + +### 6.1 不是 `[B, L]`,而是 token-flat `[T]` + +假设本轮有 `B` 个 request,第 `i` 个 request 本轮前进 `n_i` 个 token。 + +则: + +```text +T = sum_i n_i +``` + +执行时,最核心的输入通常是: + +- `input_ids`: `[T]` +- `positions`: `[T]` +- `query_start_loc`: `[B + 1]` +- `seq_lens`: `[B]` + +为了 CUDA graph 或执行约束,vLLM 里还常会有 padding 后版本: + +- `T_pad` +- `B_pad` + +于是实际 buffer 常是: + +- `input_ids`: `[T_pad]` +- `positions`: `[T_pad]` +- `seq_lens`: `[B_pad]` + + +### 6.2 `query_start_loc` 是什么 + +`query_start_loc` 是每个 request 在扁平 token buffer 中的边界。 + +如果: + +```text +n = [1, 1, 2, 6] +``` + +则: + +```text +query_start_loc = [0, 1, 2, 4, 10] +``` + +含义是: + +- 第 0 个 request 用 `input_ids[0:1]` +- 第 1 个 request 用 `input_ids[1:2]` +- 第 2 个 request 用 `input_ids[2:4]` +- 第 3 个 request 用 `input_ids[4:10]` + +这就是: + +- 一个大 flat token buffer +- 加一个分段索引数组 + +共同表达 ragged batch 的典型做法。 + + +### 6.3 decode 为何也能放进这个框架 + +decode request 在本轮通常只前进 1 个 token,所以常见: + +```text +n_i = 1 +``` + +那它在扁平 batch 里也就只占一个元素。 + +例如: + +```text +input_ids = [r1_new, r2_new, p0, p1, p2, p3] +``` + +这里前两个是 decode token,后四个是某个 prefill request 的 prompt chunk。 + +看起来 decode token 很“短”,但它并不缺上下文,因为上下文来自: + +- `seq_lens` +- `block_tables` +- `slot_mapping` +- KV cache + + +### 6.4 还有哪些重要 shape + +除了上面几个,attention 执行时还非常依赖: + +- `block_tables` + - 近似可以看成:每个 KV cache group 一份 + - 形状常见近似为 `[B_pad, max_num_blocks]` +- `slot_mapping` + - 近似为每个 token 映射到哪个 KV slot + - 常见近似为 `[T_pad]` + +因此,从 GPU 视角看,一个 batch 更接近: + +```text +flat token payload + + per-request segmentation metadata + + KV cache address metadata +``` + +而不是简单的 `input_ids` 矩阵。 + + +## 7. 一次 step 的完整生命周期 + +在 `v1` 中,可以把一次 step 概括为: + +1. `schedule()` +2. `execute_model(...)` +3. `update_from_output(...)` +4. `OutputProcessor.process_outputs(...)` + +下面按顺序拆开。 + + +### 7.1 `schedule()` + +scheduler 产生: + +- 哪些 request 参与本轮 +- 每个 request 本轮前进多少 token +- 新 request / cached request 的增量更新数据 + +同时,vLLM 还有一个很值得注意的设计: + +- request 被 schedule 到后,会先把 `num_computed_tokens` 往前推进 +- 这样它可以在下一轮继续被及时调度 +- 如果后面 speculative token 有拒绝,再在 `update_from_output()` 里回调修正 + +这说明: + +- 调度状态和最终 sample 结果之间不是完全同步的 +- 某些统计量会“先乐观推进,再按输出修正” + + +### 7.2 `execute_model(...)` + +worker 侧收到 `SchedulerOutput` 后,会做: + +- `add_requests()` + - 初始化首次进入 worker 的 request +- `update_requests()` + - 更新已有 request 的 block / token 等状态 +- `prepare_inputs()` + - 组装本轮 flat token batch +- `prepare_attn()` + - 生成 attention metadata +- 执行模型 forward +- sample token / 或做 pooling + +执行完成后返回 `ModelRunnerOutput`。 + + +### 7.3 `ModelRunnerOutput` + +这是“从 GPU 结果回到 scheduler”的关键桥梁。 + +核心字段可以理解成: + +- `req_ids`: `[B]` +- `req_id_to_index: {req_id -> batch_idx}` +- `sampled_token_ids`: `list[list[int]]` +- `logprobs` +- `prompt_logprobs_dict` +- `pooler_output` + +其中最关键的是: + +- `req_id_to_index` +- `sampled_token_ids` + +因为 worker 为了执行效率可能重排 request 顺序,所以 scheduler 回填时不能假设: + +- “第 0 个输出一定属于第 0 个 request” + +而必须显式做: + +```text +idx = req_id_to_index[req_id] +generated = sampled_token_ids[idx] +``` + + +### 7.4 `update_from_output(...)` + +这一步负责: + +- 根据 `req_id_to_index` 找到每个 request 对应的输出 +- 把 `sampled_token_ids[idx]` 回填到 request 状态 +- 检查 stop / eos / length +- 处理 speculative decode 的接受 / 拒绝 +- 必要时释放 request 的 KV cache +- 产出 `EngineCoreOutput` + +这里有一个非常重要的概念区分: + +- `n_i` + - 本轮这个 request 被安排去“计算”的 token 数 +- `g_i` + - 本轮真正“生成出来并回给请求”的 token 数 + +这两个量不一定相等。 + +典型例子: + +- chunked prefill 时,`n_i > 0`,但 `g_i = 0` +- 普通 decode 时,通常 `n_i = 1`,`g_i = 1` +- speculative decode 时,可能 `n_i = 1 + k`,而 `g_i` 可以大于 1 + + +### 7.5 `OutputProcessor.process_outputs(...)` + +这一步负责从 engine 内部输出变成用户能看到的 `RequestOutput`。 + +主要工作有: + +- detokenize +- stop string 检查 +- logprobs 处理 +- 组装 `RequestOutput` + +因此完整链路是: + +```text +SchedulerOutput + -> ModelRunnerOutput + -> EngineCoreOutput + -> RequestOutput +``` + + +## 8. 例子一:3 个 request,4 个 step + +下面给一个不带 speculative decode 的完整 toy example。 + +假设配置: + +- `max_num_seqs = 3` +- `max_num_batched_tokens = 6` +- 开启 `chunked prefill` + +4 个请求依次到达: + +- `R1 = [11,12,13,14,15]` +- `R2 = [21,22]` +- `R3 = [31,32,33,34]` +- `R4 = [41,42,43]` + +初始状态: + +```text +R1: all=[11,12,13,14,15], comp=0 +R2: all=[21,22], comp=0 +``` + + +### Step 0 + +scheduler 选择: + +- `R1` 前进 4 个 prompt token +- `R2` 前进 2 个 prompt token + +于是: + +```text +num_scheduled_tokens = {R1: 4, R2: 2} +B = 2 +T = 6 +``` + +worker 侧可能重排为: + +```text +req_ids = [R2, R1] # shape [2] +num_scheduled_tokens = [2, 4] # shape [2] +query_start_loc = [0, 2, 6] # shape [3] + +input_ids = [21,22, 11,12,13,14] # shape [6] +positions = [0,1, 0,1,2,3] # shape [6] +seq_lens = [2,4] # shape [2] +``` + +假设这一轮输出: + +- `R2` prompt 已结束,sample 到首个生成 token `23` +- `R1` 还没结束 prefill,没有生成 token + +则: + +```text +req_id_to_index = {R2: 0, R1: 1} +sampled_token_ids = [[23], []] +``` + +回填后: + +```text +R1: all=[11,12,13,14,15], comp=4 +R2: all=[21,22,23], comp=2 +``` + +注意: + +- `R2` 的 `23` 已追加到 `all_token_ids` +- 但 `comp=2`,因为本轮真正算进 KV 的还是原 prompt 的 2 个 token +- `23` 会在下一轮真正参与 decode compute + + +### Step 1 + +此时 `R3` 到达。 + +现在系统中: + +- `R1` 还差 1 个 prompt token +- `R2` 要 decode 它的 `23` +- `R3` 是新 request,需要 prefill + +于是 scheduler 可以在同一步混合调度: + +```text +num_scheduled_tokens = {R1: 1, R2: 1, R3: 4} +B = 3 +T = 6 +``` + +执行 batch: + +```text +req_ids = [R1, R2, R3] # shape [3] +query_start_loc = [0, 1, 2, 6] # shape [4] + +input_ids = [15, 23, 31,32,33,34] # shape [6] +positions = [4, 2, 0,1,2,3] # shape [6] +seq_lens = [5, 3, 4] # shape [3] +``` + +假设输出: + +```text +sampled_token_ids = [[101], [24], [35]] +``` + +回填后: + +```text +R1: all=[11,12,13,14,15,101], comp=5 +R2: all=[21,22,23,24], comp=3 +R3: all=[31,32,33,34,35], comp=4 +``` + +这一步非常重要,因为它体现了: + +- `R1` 还在补最后一段 prefill +- `R2` 已在 decode +- `R3` 是新 request 的整段 prefill + +三者可以在同一个 step 里并存。 + + +### Step 2 + +现在三者都进入正常 decode: + +```text +num_scheduled_tokens = {R1:1, R2:1, R3:1} +B = 3 +T = 3 + +input_ids = [101, 24, 35] # shape [3] +query_start_loc = [0, 1, 2, 3] # shape [4] +``` + +假设输出: + +```text +sampled_token_ids = [[102], [2], [36]] +``` + +若 `2` 是 EOS,则: + +- `R2` 完成 +- `R2` 的 KV block 可被释放 +- `R1`、`R3` 继续保留 + + +### Step 3 + +此时 `R4` 到达,于是可以立刻填补空位: + +```text +num_scheduled_tokens = {R1:1, R3:1, R4:3} +B = 3 +T = 5 + +input_ids = [102, 36, 41,42,43] # shape [5] +query_start_loc = [0, 1, 2, 5] # shape [4] +``` + +这就是 continuous batching 的最直观体现: + +- 老 request 结束后立即移出 +- 新 request 马上补进来 +- 系统不是“整批结束再换下一批”,而是每一步都在流动 + + +## 9. 例子二:为什么 `n_i` 不等于 `g_i` + +很多人第一次看时会默认: + +- scheduler 安排这个 request 算 4 个 token +- 那它就应该返回 4 个 token + +这在 vLLM 中并不成立。 + +### 9.1 chunked prefill 场景 + +假设一个长 prompt request: + +```text +prompt = [p0, p1, p2, p3, p4, p5, p6, p7] +num_computed_tokens = 0 +``` + +本轮只给它 4 个 token budget: + +```text +n_i = 4 +input_ids = [p0, p1, p2, p3] +``` + +如果 prompt 还没 prefill 完,则这轮: + +```text +g_i = 0 +``` + +也就是: + +- 本轮确实算了 4 个 token +- 但没有新生成 token 回给用户 + + +### 9.2 普通 decode 场景 + +如果一个 request 已经完成 prefill,只差 decode: + +```text +n_i = 1 +g_i = 1 +``` + +这是最常见、也最容易理解的情况。 + + +### 9.3 speculative decode 场景 + +如果开启 speculative decode,情况会变成: + +- 本轮可能先有若干 draft token +- target verify 后可能一次接受多个 token + +于是可能出现: + +```text +n_i = 1 + k +g_i = m +``` + +其中: + +- `k` 是 draft 相关的额外计算 +- `m` 是最终接受并回填的 token 数 +- `m` 可以大于 1 + + +## 10. 例子三:speculative decode 的 shape 直觉 + +假设某个 request 本轮有 3 个 draft token: + +```text +scheduled_spec_decode_tokens = { + R1: [501, 502, 503] +} +``` + +worker 执行后,假设 target verify: + +- 接受了前 2 个 draft token +- 然后再给出 1 个新的 target token + +则返回到 scheduler 的 `generated_token_ids` 可能近似为: + +```text +sampled_token_ids[idx_of_R1] = [501, 502, 900] +``` + +此时: + +- `num_draft_tokens = 3` +- `num_accepted = len(generated_token_ids) - 1 = 2` +- `num_rejected = 3 - 2 = 1` + +也就是说: + +- 被接受的 draft token 会直接作为 output 回填 +- 被拒绝的 draft token 要把之前乐观推进的 `num_computed_tokens` + 再修正回来 + +这也是为什么 scheduler 与 output update 之间会有一个“先推进、后修正”的配合。 + + +## 11. v0 和 v1 的关系 + +如果你看旧版文章,经常会看到: + +- `Sequence` +- `SequenceGroup` +- `ScheduledSequenceGroup` +- `RequestOutput.from_seq_group(...)` + +这主要是 `v0` 风格。 + +### 11.1 v0 的回填方式 + +旧版 `LLMEngine.step()` 的高层流程大致是: + +```text +scheduler.schedule() + -> model_executor.execute_model(...) + -> _process_model_outputs(...) + -> RequestOutput.from_seq_group(...) +``` + +这里的回填更偏向: + +- 先把 sampler output 按 sequence group 拆好 +- 再依次更新每个 `SequenceGroup` + + +### 11.2 v1 的回填方式 + +`v1` 更偏向 request-centric: + +- scheduler 输出 `num_scheduled_tokens` +- worker 返回 `req_id_to_index` +- scheduler 用 `req_id_to_index` 查表回填 + +所以: + +- `v0` 更像“按 sequence group 顺序回填” +- `v1` 更像“按 req_id 显式映射回填” + +但本质是一样的: + +> 执行 batch 在 GPU 侧可以重排、压平、做优化; +> 但 step 结束后必须有一套稳定映射,把输出放回正确的 request。 + + +## 12. 推荐源码入口 + +下面给出一份更适合顺着看的源码索引。 + +### 12.1 v1 主线 + +#### 1. 调度输出结构 + +- `vllm/v1/core/sched/output.py` + - `NewRequestData` + - `CachedRequestData` + - `SchedulerOutput` + +建议先看它,因为它定义了 scheduler 究竟在给 worker 发送什么。 + + +#### 2. scheduler 主入口 + +- `vllm/v1/core/sched/scheduler.py` + - `Scheduler.schedule()`,约 `348` + - `_make_cached_request_data()`,约 `1055` + - `update_from_output()`,约 `1302` + +这是最核心的一组函数。 + +尤其推荐按下面顺序看: + +1. `schedule()` +2. `_make_cached_request_data()` +3. `update_from_output()` + + +#### 3. engine step + +- `vllm/v1/engine/core.py` + - `EngineCore.step()`,约 `380` + +这能把高层链路串起来: + +```text +schedule + -> execute_model + -> update_from_output +``` + + +#### 4. worker 侧 batch 组装 + +- `vllm/v1/worker/gpu/model_runner.py` + - `add_requests()`,约 `612` + - `update_requests()`,约 `657` + - `prepare_inputs()`,约 `667` + +如果你最关心 shape,`prepare_inputs()` 是必须看的。 + +它直接体现: + +- `num_scheduled_tokens -> query_start_loc` +- `flat input_ids / positions` +- `seq_lens` +- `cu_num_logits` +- speculative decode 相关展开 + + +#### 5. 输出结构 + +- `vllm/v1/outputs.py` + - `SamplerOutput` + - `ModelRunnerOutput` + +这决定了从 GPU 回 scheduler 时到底带了哪些数据。 + + +#### 6. engine 内部输出与最终用户输出 + +- `vllm/v1/engine/__init__.py` + - `EngineCoreOutput` + - `EngineCoreOutputs` +- `vllm/v1/engine/output_processor.py` + - `RequestState.make_request_output()`,约 `269` + - `OutputProcessor.process_outputs()`,约 `572` + +这部分更偏“回填后的用户接口层”。 + + +### 12.2 旧版 v0 参考线 + +- `vllm/engine/llm_engine.py` + - `_process_model_outputs(...)`,约 `510` + - `step()`,约 `557` + +适合在下面两种情况下参考: + +- 你看到旧文档 / issue 还在讲 `SequenceGroup` +- 你想对照理解 vLLM 是怎样从旧结构演化到 `v1` 的 + + +## 13. 看源码时建议抓住的 5 个问题 + +如果你在调试 continuous batching,建议始终围绕下面几个问题读代码。 + +### 13.1 本轮到底调度了哪些 request + +看: + +- `num_scheduled_tokens` +- `scheduled_new_reqs` +- `scheduled_cached_reqs` + + +### 13.2 每个 request 本轮前进了多少 token + +看: + +- `n_i = num_scheduled_tokens[req_id]` + + +### 13.3 扁平 batch 的边界在哪里 + +看: + +- `query_start_loc` +- `seq_lens` + + +### 13.4 输出怎么知道属于哪个 request + +看: + +- `req_id_to_index` +- `sampled_token_ids[idx]` + + +### 13.5 request 状态什么时候推进,什么时候修正 + +看: + +- schedule 后 `num_computed_tokens` 的推进 +- speculative decode 拒绝后在 `update_from_output()` 中的修正 + + +## 14. 常见误区 + +### 14.1 “一个 batch 就是一个 `input_ids` 矩阵” + +不对。 + +vLLM 更接近: + +```text +input_ids[T] + + positions[T] + + query_start_loc[B+1] + + seq_lens[B] + + block_tables + + slot_mapping +``` + + +### 14.2 “decode 只输入 1 个 token,所以计算很简单” + +不对。 + +decode 本轮虽然只新输入 1 个 token id,但它会读取整条历史序列对应的 KV cache, +真正的上下文并没有消失。 + + +### 14.3 “本轮调度了几个 token,就一定返回几个 token” + +不对。 + +要始终区分: + +- 计算的 token 数 `n_i` +- 生成并回填的 token 数 `g_i` + + +### 14.4 “prefill 和 decode 是两套完全不同的调度器” + +不对。 + +在 vLLM 的设计里,它们更像是同一个“按 backlog 前进”的调度框架下的不同常见情形。 + + +## 15. 最终总结 + +从实现上看,`vLLM continuous batching` 的关键可以归纳成下面几句话: + +- scheduler 的核心决策单位不是“固定长度序列”,而是“本步每个 request 前进多少 token” +- worker 的执行核心不是规则的 `B x L`,而是 flat token batch `[T]` +- `query_start_loc / seq_lens / block_tables / slot_mapping` 决定了这些 token 如何映射回各自 request 与 KV cache +- step 结束后,`req_id_to_index` 负责把输出准确拆回 request +- `n_i` 与 `g_i` 不一定相等,这一点对理解 chunked prefill 与 speculative decode 非常重要 + +所以,continuous batching 真正连续流动的不是“一个静态矩阵”,而是: + +- request 集合在流动 +- 每步前进 token 数在流动 +- GPU token batch 的形状在流动 +- 完成与新加入的 request 在每个 step 都会重新重组 + +这正是它能在在线 serving 中同时兼顾: + +- 吞吐 +- 低延迟 +- 动态请求混合 + +的根本原因。 diff --git a/work_log/MTP/2026-04-10-sglang-cudagraph-prefill-decode-metadata-guide.md b/work_log/MTP/2026-04-10-sglang-cudagraph-prefill-decode-metadata-guide.md new file mode 100644 index 000000000..3250f331d --- /dev/null +++ b/work_log/MTP/2026-04-10-sglang-cudagraph-prefill-decode-metadata-guide.md @@ -0,0 +1,734 @@ +# SGLang CUDAGraph、Prefill/Decode 与 Attention Metadata 说明 + +## 文档目的 + +这篇文档专门回答下面几个问题: + +- `CUDAGraph` 在 `SGLang` 里到底固定了什么 +- 为什么 `decode` 更适合做 `CUDAGraph` +- 为什么普通 `prefill / extend` 的 attention 很难复用同一张 graph +- `SGLang` 在 `decode` 阶段是怎样做 `capture / replay` 的 +- `attention backend` 的 `ForwardMetadata` 在 graph capture / replay 中扮演什么角色 +- 在 `ATOM plugin` 模式下,哪些 graph-only bug 容易出现,为什么 + + +## 一句话结论 + +最重要的结论先说: + +- `CUDAGraph` 固定的不是“某个 Python 函数调用”,而是一整段已经展开好的 CUDA 执行脚本 +- `decode` 更适合 graph,不是因为它没有 metadata,而是因为它的 **query 结构、token 数、kernel 形状、workspace 结构** 更稳定 +- 普通 `prefill / extend` 难 graph,不是因为 kernel 不能读不同的 metadata,而是因为 metadata 往往不只是“输入数据”,而是会影响: + - 走哪条代码路径 + - 中间 tensor 分配多大 + - gather 后张量有多长 + - workspace 形状是什么 + - 最终 launch 的 kernel 形态是什么 +- 换句话说: + - `decode` 下,metadata 更像 **kernel 参数** + - `prefill` 下,metadata 更像 **图结构控制器** + + +## 1. CUDAGraph 真正固定的是什么 + +很多人第一次接触 `CUDAGraph` 时,会误以为它只是“把一次 forward 缓存起来”。 + +更准确地说,`CUDAGraph` capture 固定的是: + +- 这次 forward 里实际 launch 了哪些 CUDA kernel +- kernel 的调用顺序 +- 每个 kernel 看到的 tensor shape / stride +- 这些 tensor 和 workspace 的内存地址 +- Python 层已经展开后的控制流分支 + +因此: + +- **tensor 的值**可以变 +- 但 **shape / 地址 / 分支 / launch 计划** 最好不要变 + +可以把它想成: + +- eager 模式像“每次现写一遍执行计划” +- cuda graph 像“录下这次执行计划,以后按原样回放” + + +## 2. 三个最容易混淆的量:`raw_bs`、`bs`、`num_tokens` + +理解 graph 之前,必须先区分三个量: + +- `raw_bs` + - 当前真实 batch 里有多少个 request +- `bs` + - 当前 replay 选中的 graph bucket 大小 +- `num_tokens` + - 这次真正传给很多 layer 的 token 数 + +它们经常不相等。 + +### 2.1 `raw_bs` + +这是 scheduler 当前真实调度出来的 request 数。 + +例如: + +- 真实只有 3 个 request 要做 decode +- 那么 `raw_bs = 3` + +### 2.2 `bs` + +这是 graph 系统为了复用固定 shape,选中的 capture bucket。 + +例如 capture 过这些 bucket: + +- `[1, 2, 4, 8, 16, 32, 48]` + +如果这次真实请求数是 3,系统可能会选择: + +- `bs = 4` + +然后: + +- 前 3 个位置放真实请求 +- 第 4 个位置放 padding / fill value + +### 2.3 `num_tokens` + +这不是永远等于 `bs`。 + +它取决于当前模式下“每个 request 本轮贡献多少 query token”。 + +几个典型场景: + +| 场景 | `num_tokens` | +|------|--------------| +| 普通 decode | `bs * 1` | +| target verify | `bs * num_draft_tokens` | +| draft decode | `bs * topk` | +| draft extend | `bs * (speculative_num_steps + 1)` | +| 普通 prefill / extend | 通常是 `sum(extend_seq_lens)`,不一定等于 `bs * 常数` | + +这也是为什么: + +- 很多 layer 看到的输入 shape 是 `[num_tokens, hidden_size]` +- 而 graph bucket 却还是按 `bs` 来管理 + + +## 3. SGLang 为什么默认把 graph 重点放在 decode + +`SGLang` 的通用 `CudaGraphRunner` 默认 capture 的 forward mode 是 `DECODE`: + +- 初始化时先设: + - `capture_forward_mode = ForwardMode.DECODE` + - `num_tokens_per_bs = 1` +- 若是 speculative target verify,再切成: + - `ForwardMode.TARGET_VERIFY` + - `num_tokens_per_bs = speculative_num_draft_tokens` +- 若是 `DLLM_EXTEND`,再切成 block-size 固定模式 + +关键点是: + +- 这些模式都满足“每个 request 贡献固定个数的 query token” + +而普通 `prefill / extend` 不满足这一点。 + +从 `sglang/python/sglang/srt/model_executor/cuda_graph_runner.py` 可以直接看到这件事: + +- graph runner 默认按 `DECODE` 组织 +- `num_tokens_per_bs` 是固定常数 +- 再用它去算: + - `max_bs` + - `max_num_token` + - 静态输入 buffer 的大小 + + +## 4. 为什么 decode 比 prefill 更适合 graph + +### 4.1 decode 的 query 结构稳定 + +decode 下,一个 request 往往只算一个 query token。 + +因此: + +- `max_q_len` 通常固定为 `1` +- `num_tokens = bs` +- `qo_indptr` 结构非常规则 +- kernel 形状更容易随 `bs bucket` 固定下来 + +即使 metadata 中像: + +- `kv_indptr` +- `kv_indices` +- `kv_last_page_len` + +这些值每轮都变,它们大多数时候也只是: + +- 作为固定 kernel 的输入索引参数 + +而不是决定“这次图长什么样”。 + +### 4.2 prefill / extend 的 query 结构是 ragged 的 + +prefill / extend 下,每个 request 这一轮要处理多少 query token,通常不一样。 + +例如: + +- request A 新增 3 个 token +- request B 新增 17 个 token +- request C 新增 1 个 token + +这时: + +- `num_tokens = 3 + 17 + 1` +- `qo_indptr` 随分布变化 +- `max_q_len` 随分布变化 +- `max_kv_len` 也随上下文长度变化 + +这不是简单的“值不同”,而是 batch 的 **几何结构** 不同。 + + +## 5. 为什么“metadata 改变”会阻碍 prefill graph 复用 + +这个问题最容易被误解。 + +### 5.1 先说清楚:metadata 变,不一定阻碍 capture + +如果你拿某一个固定的 prefill batch 去做 capture,这次 capture 可能是成功的。 + +因为那一刻: + +- `q.shape` +- `kv_indices.shape` +- `qo_indptr` +- `max_q_len` +- `max_kv_len` + +都是确定的。 + +所以问题不在“这次能不能录下来”,而在: + +- **下一次不同的 prefill batch 还能不能 replay 这张图** + +### 5.2 decode 中 metadata 更像“数据” + +decode 中,metadata 变化通常只是: + +- 不同 request 对应不同 KV 索引 +- 不同 request 当前上下文长度不同 + +但最终仍然是在执行同一类 decode kernel。 + +所以它们更像: + +- 同一张图里的输入数据 + +### 5.3 prefill 中 metadata 更像“图结构控制器” + +在普通 MLA extend / prefill 里,metadata 会直接影响: + +1. 走哪条 Python 分支 +2. 中间张量 shape +3. workspace 大小 +4. gather 结果长度 +5. kernel 的 `max_q_len / max_kv_len` + +这就是根本区别。 + + +## 6. 用 ATOM plugin 的 MLA extend 代码看这个问题 + +`ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` 里的 `MLA extend` 很能说明问题。 + +### 6.1 metadata 先决定本轮的 ragged 结构 + +普通 MLA extend 初始化时会根据当前 batch 更新: + +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `max_q_len` +- `max_kv_len` + +它们都不是常量,而是来自当前 batch 的 `extend_seq_lens / seq_lens`。 + +### 6.2 metadata 决定走哪条代码路径 + +在 `_forward_extend_mla_normal()` 里,代码会根据 prefix 情况和 cache 形态走不同分支: + +- 无 prefix +- 有 prefix 且要 decompress +- 有 prefix 且走 absorbed MLA + +这意味着: + +- 不同 batch 可能走完全不同的子函数 +- graph capture 录下来的并不是“抽象 extend”,而是“某一条具体 extend 分支” + +如果第一次 capture 时: + +- `extend_no_prefix = True` + +那录下来的是 `_extend_mla_no_prefix()` 这条图。 + +下一次如果: + +- `extend_no_prefix = False` + +且要走 `_extend_mla_absorbed_prefix()`,那已经不是同一张图。 + +### 6.3 metadata 决定中间 tensor 的 shape + +在 no-prefix prefill 路径里,会根据当前 query token 总数构造: + +- `temp_kv_indices` +- `output` + +它们的 shape 直接依赖: + +- `q.shape[0]` +- `total_s` + +而 `q.shape[0]` 本身就是当前 ragged batch 展平后的 token 数。 + +换一个 prefill batch: + +- `total_s` 变了 +- 中间 tensor shape 跟着变 + +那 graph 也就不再可复用。 + +### 6.4 metadata 决定 workspace 的 shape + +FP8 prefill 路径里,还会根据: + +- `reduce_partial_map.size(0)` +- `total_s` + +分配: + +- `logits` +- `attn_lse` +- `final_lse` +- `output` + +而 `reduce_partial_map` 正是从当前 batch 的分段结构推出来的。 + +所以这不是“kernel 读不同 metadata”这么简单,而是: + +- metadata 直接控制要分配多大的临时缓冲区 + +### 6.5 metadata 决定 gather 后张量长度 + +在 absorbed prefix 路径里,会先: + +- `k_selected = torch.index_select(K_Buffer, 0, kv_indices)` + +这里 `k_selected.shape[0]` 就等于: + +- `len(kv_indices)` + +而 `kv_indices` 的长度也是当前 batch 的结构量。 + +因此: + +- prefix KV gather 后的张量 shape 也会跟 batch 变化 + +这会继续向下游 kernel 传播。 + + +## 7. 为什么不能简单靠 padding 解决普通 prefill + +有人会自然想到: + +- 既然 decode 能靠 bucket + padding 做 graph +- 那 prefill 也可以 pad 到固定 `bs / max_q_len / max_kv_len` + +理论上不是完全不行,但工程上代价很大。 + +### 7.1 decode 的 padding 成本小 + +decode 一般每个 request 只处理一个 token。 + +所以即使: + +- `raw_bs = 3` +- `bs = 4` + +多 pad 一个 request 的成本也比较低。 + +### 7.2 prefill 的 padding 成本会放大 attention 计算 + +prefill attention 的成本接近: + +- query token 数 +- context 长度 +- ragged 结构 + +的组合增长。 + +如果为了 graph,把所有 request 都 pad 成: + +- 大 `max_q_len` +- 大 `max_kv_len` + +那么: + +- 无效 token 也要参与很多 attention 计算 +- mask / metadata 也会跟着变大 +- workspace 和显存开销也会膨胀 + +最后可能: + +- graph 省下来的 launch 开销 +- 远远抵不过 padding 带来的额外 attention FLOPs + + +## 8. 为什么 `TARGET_VERIFY` / `DRAFT_EXTEND` 又能 graph + +因为它们虽然也不是普通 decode,但仍然满足: + +- 每个 request 的 query token 数是固定常数 + +例如: + +- `TARGET_VERIFY` + - 每个 request 验证 `num_draft_tokens` 个 token +- `DRAFT_EXTEND` + - 每个 request 固定处理 `speculative_num_steps + 1` 个 token + +所以它们仍然可以用: + +- `bs bucket` +- `num_tokens_per_bs` + +来组织 graph。 + +换句话说: + +- 它们不是“完全自由的 ragged prefill” +- 而是“固定 token-per-request 的特殊 extend” + +因此 graph 化难度明显低于普通 prefill。 + + +## 9. SGLang 在 decode 阶段怎样做 CUDAGraph capture + +下面按实际代码链路讲。 + +### 9.1 第一步:决定 capture 模式和 bucket + +`CudaGraphRunner.__init__()` 中会: + +1. 设定 `capture_forward_mode` +2. 设定 `num_tokens_per_bs` +3. 通过 `get_batch_sizes_to_capture()` 得到 `capture_bs` +4. 算出: + - `max_bs` + - `max_num_token = max_bs * num_tokens_per_bs` + +这一步的意义是: + +- graph 系统先把“这类 forward 的形状规则”固定下来 +- 然后再一次性分配足够大的静态 buffer + +### 9.2 第二步:attention backend 先分配 graph 专用静态状态 + +接着会调用: + +- `attn_backend.init_cuda_graph_state(max_bs, max_num_token)` + +在 `ATOMAttnBackendForSgl` 里,这一步会分配 graph 期间复用的持久 buffer,例如: + +- `cuda_graph_kv_last_page_len` +- `cuda_graph_kv_indices` +- `page_table` +- `seq_lens` +- MLA decode 的 `work_metadata / work_indptr / work_info_set / reduce_*` + +这里的关键思想是: + +- graph replay 期间,不再频繁新建这些结构 +- 而是在固定 buffer 上反复更新其内容 + +### 9.3 第三步:为某个具体 bucket 构造静态输入视图 + +在 `capture_one_batch_size(bs)` 中,会从大 buffer 上切出本 bucket 对应的视图,例如: + +- `input_ids = buffers.input_ids[:num_tokens]` +- `req_pool_indices = buffers.req_pool_indices[:bs]` +- `seq_lens = buffers.seq_lens[:bs]` +- `positions = buffers.positions[:num_tokens]` + +然后构造一个 `ForwardBatch`: + +- `forward_mode = capture_forward_mode` +- `batch_size = bs` +- 大部分字段都直接绑定到这些静态 buffer 视图上 + +### 9.4 第四步:capture 前先初始化 attention metadata + +在真正 `graph capture` 之前,先调用: + +- `attn_backend.init_forward_metadata_capture_cuda_graph(...)` + +对 decode 来说,这一步本质上是: + +- 根据当前 `req_pool_indices / seq_lens` +- 把 `ForwardMetadata` 组装到 graph 专用的静态 buffer 视图上 + +对于 MLA decode,它会构造: + +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `kv_last_page_len` +- `work_metadata / work_indptr / work_info_set / reduce_*` + +而这些对象大多来自: + +- graph state 中预先分配好的持久 buffer + +### 9.5 第五步:跑几次 warmup,再进入真正 graph capture + +`capture_one_batch_size()` 里会: + +1. 先调用几次 `run_once()` +2. 再进入 `torch.cuda.graph(...)` +3. 把这次 forward 录下来 + +在这个过程中: + +- 输入 buffer 地址固定 +- metadata buffer 地址固定 +- forward mode 固定 +- kernel 形状固定 + +于是得到一张与 bucket `bs` 绑定的 graph。 + + +## 10. SGLang 在 decode 阶段怎样做 replay + +### 10.1 先从真实 batch 选一个 bucket + +在 `replay_prepare()` 中: + +1. 读取真实 batch 的: + - `raw_bs` + - `raw_num_token` +2. 从 `capture_bs` 中找一个: + - `bs >= raw_bs` + +这一步就是把真实 batch 映射到 graph bucket。 + +### 10.2 把真实数据 copy 到静态 buffer 的前缀 + +调用: + +- `buffers.populate_from_forward_batch(...)` + +会把真实 batch 的内容写入静态 buffer 的前缀区域,例如: + +- `input_ids[:raw_num_token]` +- `req_pool_indices[:raw_bs]` +- `seq_lens[:raw_bs]` +- `positions[:raw_num_token]` + +如果 `bs != raw_bs`,还会: + +- 用 fill value / zero 对后面的 padding 段做补齐 + +### 10.3 replay 前重建本轮 metadata + +随后调用: + +- `attn_backend.init_forward_metadata_replay_cuda_graph(...)` + +注意这一步非常关键: + +- graph replay 不是复用 capture 当时的 metadata 值 +- 而是复用 **metadata 的静态 buffer 与构造方式** +- 然后把本轮真实 batch 的索引内容重新写进去 + +也就是说: + +- 地址固定 +- 内容可变 + +对 decode 来说,这正是 graph 友好的做法。 + +### 10.4 最后 `graph.replay()` + +当静态 buffer 和 metadata 都准备好后,就直接: + +- `self.graphs[graph_key].replay()` + +执行那张已经 capture 好的图。 + +输出拿到后,再按照: + +- `raw_bs` +- `raw_num_token` + +把 padding 的尾部裁掉。 + + +## 11. Attention Metadata 在 graph capture / replay 中的角色 + +可以把 `ForwardMetadata` 在 graph 中的角色概括成一句话: + +- 它不是 graph 外的额外说明书 +- 它是 graph 里 attention kernel 的直接输入 + +但它在不同模式下的“地位”不同。 + +### 11.1 decode 中:metadata 更像固定地址上的输入参数 + +decode graph 下,像: + +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `kv_last_page_len` + +更多是在表达: + +- 当前 batch 的 KV 可见范围 +- 当前 batch 的 query 分段 + +这些值会变,但: + +- 它们所在的 buffer 地址固定 +- 它们的 shape 规则受 bucket 控制 +- 下游还是同一类 decode kernel + +所以 graph 可以复用。 + +### 11.2 prefill 中:metadata 会升级成“图结构的一部分” + +prefill / extend 下,metadata 往往不仅仅被 kernel 读取,还会影响: + +- 选择哪条路径 +- 构造哪些中间 tensor +- 中间 tensor 有多大 +- workspace 有多大 +- kernel 看到的 `max_q_len / max_kv_len` + +因此: + +- metadata 变化会把“图长什么样”一起改掉 + +这就是它阻碍 graph 复用的根本原因。 + + +## 12. `ATOMAttnBackendForSgl` 中 graph metadata 的几个关键点 + +### 12.1 `init_cuda_graph_state()`:先分配 graph 专用持久 buffer + +plugin backend 里专门分配了: + +- `cuda_graph_kv_last_page_len` +- `cuda_graph_kv_indices` +- `page_table` +- `seq_lens` +- MLA decode 的 persistent workspace + +这样 replay 时就能复用这些地址。 + +### 12.2 `init_forward_metadata_capture_cuda_graph()`:把 bucket 数据写成 metadata + +这一步会根据当前 mode 做不同初始化: + +- `decode_or_idle` +- `target_verify` +- `draft_extend` + +每种模式都把: + +- bucket 对应的 `bs` +- 固定的 `num_tokens_per_bs` +- 当前 request 索引和 seq_lens + +转成 kernel 需要的 metadata。 + +### 12.3 `init_forward_metadata_replay_cuda_graph()`:重建本轮 metadata + +replay 时,plugin backend 不会继续沿用 capture 时那一轮的 metadata 值,而是: + +- 在固定 graph buffer 上 +- 根据本轮真实 batch 重建一次 metadata + +这一步必须非常小心“当前 bucket 视图”和“整块静态 buffer”的区别。 + +最近在 debug 中出现的一个典型 graph-only bug 正是: + +- 上游 replay 某条 speculative 路径把整块静态 buffer 传下来 +- plugin backend 按“已经是当前 `bs` 视图”去理解 +- 于是出现: + - `bs = 1` + - `seq_lens.shape[0] = 48` + +后来在 plugin backend 里做了统一切片规整: + +- `req_pool_indices = req_pool_indices[:bs]` +- `seq_lens = seq_lens[:bs]` +- `seq_lens_cpu = seq_lens_cpu[:bs]` + +本质上就是把 replay 的输入重新对齐到“当前 bucket 视图”。 + + +## 13. 这次 debug 暴露出的两个 graph-only 经验 + +### 13.1 backend 选型必须真的落到 plugin backend + +之前 `kv_last_page_len` 掉到 CPU 的问题,最后定位到: + +- `AiterMultiStepDraftBackend` 内部直接实例化 `AiterAttnBackend` +- 绕过了 plugin 通过 registry 注册的 `"aiter" -> ATOMAttnBackendForSgl` + +这说明: + +- graph-only 路径里,某些 backend 可能不是从常规 registry 路径拿到的 +- 如果 direct construction 没 patch 到,graph state 就可能偷偷回落到 upstream 实现 + +### 13.2 replay 必须明确区分“静态大 buffer”和“当前 bucket 视图” + +graph replay 中,静态 buffer 通常按: + +- `max_bs` +- `max_num_token` + +一次性分配。 + +但 backend 在构 metadata 时,真正应该看到的是: + +- 当前 bucket 的前 `bs` +- 当前 token 的前 `num_tokens` + +一旦把整块静态 buffer 当成当前视图使用,就很容易出现: + +- shape mismatch +- CPU / CUDA tensor 混用 +- metadata 与实际 batch 不一致 + + +## 14. 用一句工程化的话总结 + +如果只用一句最工程化的话来总结这篇文档: + +- `decode` 图里,metadata 大多是 **固定形状 graph 的输入数据** +- `prefill` 图里,metadata 往往会变成 **决定图形状和执行路径的结构量** + +因此: + +- `decode` 适合用 bucket + padding + 静态 buffer 做 graph +- 普通 `prefill / extend` 则很难在收益合理的前提下复用同一张 graph + + +## 15. 最后总结 + +记住下面五句话就够了: + +1. `CUDAGraph` 固定的是一整段具体 CUDA 执行计划,不只是 Python 函数入口。 +2. `raw_bs` 是真实请求数,`bs` 是 graph bucket,`num_tokens` 是真正传给很多 layer 的 token 数。 +3. `decode` 更适合 graph,因为每个 request 的 query 结构更稳定,metadata 更像输入参数。 +4. 普通 `prefill / extend` 难 graph,因为 metadata 会影响分支、shape、workspace 和 kernel 形态,升级成图结构的一部分。 +5. 在 `SGLang + ATOM plugin` 里,graph replay 的关键不是“重复使用旧 metadata 值”,而是“在固定 buffer/地址上重建本轮 metadata 内容”。 diff --git a/work_log/MTP/2026-04-10-sglang-simple-prefill-cudagraph-metadata-guide.md b/work_log/MTP/2026-04-10-sglang-simple-prefill-cudagraph-metadata-guide.md new file mode 100644 index 000000000..c2e4e4b6b --- /dev/null +++ b/work_log/MTP/2026-04-10-sglang-simple-prefill-cudagraph-metadata-guide.md @@ -0,0 +1,866 @@ +# 最简单 Prefill、CUDAGraph 与 Metadata 速查 + +## 文档目的 + +这篇文档只回答一个收窄后的问题: + +- **不考虑不同 kernel path** +- **不考虑 prefix cache** +- **不考虑 speculative target_verify / draft_extend** +- **只考虑最普通、最简单的 prefill / extend** + +在这个前提下,说明: + +1. 为什么这种最简单的 prefill 仍然难做 `CUDAGraph` +2. attention metadata 的核心字段在这种场景下分别表示什么 +3. 给几个可以手算的小例子,方便以后速查 + + +## 一句话结论 + +即使只看最简单 prefill,`CUDAGraph` 的挑战仍然存在。根本原因不是“metadata 会变”本身,而是: + +- `total_tokens` 会变 +- `max_q_len / max_kv_len` 会变 +- ragged metadata 会跟着 batch 几何结构一起变 +- 很多中间 tensor / workspace 的 shape 也会变 + +所以问题不是: + +- kernel 能不能读取不同的 metadata 值 + +而是: + +- **同一个 prefill batch family,能不能稳定成一张固定 shape 的图** + + +## 1. 本文说的“最简单 prefill”是什么 + +这里约定的“最简单 prefill”是: + +- 没有 prefix cache +- 没有 speculative 分支 +- 不讨论不同 kernel path 的切换 +- 假定已经选中某一条固定的 prefill kernel 路径 +- 一个 batch 里有若干 request +- 每个 request 本轮需要处理若干 query token +- attention 以 ragged / varlen 形式运行 + +可以把它理解成: + +- `ForwardMode.EXTEND` +- `spec_info = None` +- `extend_prefix_lens = 0` +- attention backend 已经决定“就走这条 prefill kernel” + +本文不讨论: + +- prefix/no-prefix 的 kernel 分流 +- absorbed / decompress 等 MLA 专有分流 +- draft_extend / target_verify +- decode + + +## 2. 最简单 prefill 的数据形状 + +prefill 和 decode 最大的不同在于: + +- decode 常常是每个 request 本轮只算 1 个 token +- prefill 常常是每个 request 本轮要算多个 token + +因此,很多 layer 真正看到的不是: + +- `[bs, hidden_size]` + +而是: + +- `[total_tokens, hidden_size]` + +其中: + +- `bs` = request 数 +- `total_tokens` = 本轮所有 request 的 query token 总数 + +在最简单 prefill 下,常见关系是: + +```text +total_tokens = sum(extend_seq_lens) +``` + +这意味着: + +- 即使 `bs` 不变 +- 只要每个 request 的长度分布变了 +- `total_tokens` 就会变 + + +## 3. 为什么最简单 prefill 仍然难做 CUDAGraph + +下面只看“最简单 prefill”,不引入分支复杂度。 + +### 3.1 `q/k/v/o` 的 token 维会变 + +prefill 下最直观的问题就是: + +- `q.shape[0] = total_tokens` +- `k.shape[0] = total_tokens` +- `v.shape[0] = total_tokens` +- `o.shape[0] = total_tokens` + +只要: + +- request 数不同 +- 或每个 request 的 query 长度分布不同 + +那么: + +- `total_tokens` 就不同 +- 上面这些张量 shape 就不同 + +而 `CUDAGraph` 更喜欢的是: + +- tensor shape 固定 +- graph 中 kernel launch 形态固定 + +这已经是第一层挑战。 + +### 3.2 `max_q_len / max_kv_len` 会变 + +即使不看 `q.shape[0]`,varlen attention 往往还会显式传: + +- `max_q_len` +- `max_kv_len` +- `cu_seqlens_q` 或 `qo_indptr` + +这些量不是 decoration,而是 kernel 的核心输入。 + +例如对于一个 batch: + +- request A: 3 tokens +- request B: 2 tokens + +则: + +- `max_q_len = 3` + +如果下一个 batch 是: + +- request A: 4 tokens +- request B: 1 token + +则: + +- `max_q_len = 4` + +虽然: + +- 两个 batch 的 `bs = 2` +- 两个 batch 的 `total_tokens = 5` + +但: + +- `qo_indptr` 不同 +- `max_q_len` 不同 + +这意味着: + +- 内部 tile / launch 策略可能不同 +- workspace 需求也可能不同 + +### 3.3 ragged metadata 在描述“问题几何结构” + +在 decode 里,很多 metadata 更像: + +- 固定图上的输入索引数据 + +而在 prefill 里,metadata 往往在表达: + +- 一共有多少 query token +- 这些 query token 怎样按 request 分段 +- KV token 怎样按 request 分段 +- 当前 batch 的最大 query / KV 长度是多少 + +所以它不只是“值会变”,而是在描述: + +- **这轮 attention 问题本身长什么样** + +这就让同一张图更难复用。 + +### 3.4 中间 tensor / workspace 的 shape 也会变 + +哪怕我们强行假设: + +- kernel path 不变 + +很多中间结构也仍然可能随 batch 变化。 + +例如: + +- 某些临时索引张量长度跟 `total_tokens` 走 +- 某些 workspace 大小跟 `max_q_len / max_kv_len` 走 +- 某些 reduce buffer 大小跟分段结构走 + +所以问题不止在输入张量,而是: + +- graph 内部很多“中间物体”的 shape 也不稳定 + +### 3.5 同一个分支里也可能无法稳定 replay + +这点最容易误解。 + +即使你已经保证: + +- 一定走同一个 prefill kernel path + +也不代表可以 graph。 + +因为同一条路径里仍然可能有: + +- `total_tokens` 变化 +- `max_q_len` 变化 +- `max_kv_len` 变化 +- workspace shape 变化 + +所以: + +- “分支固定” + +并不等于: + +- “图固定” + + +## 4. Metadata 速查表 + +下面只保留最常用于“最简单 prefill”理解的字段。 + +### 4.1 高层 batch 字段 + +| 字段 | 常见 shape | 物理意义 | 简单 prefill 下的作用 | +|------|------------|----------|-----------------------| +| `bs` | Python `int` | request 数 | 当前 batch 有几个 request | +| `extend_seq_lens` | `[bs]` | 每个 request 本轮 query token 数 | 决定 `total_tokens` 和 `qo_indptr` | +| `seq_lens` | `[bs]` | 每个 request 当前可见 KV 长度 | 决定 `kv_indptr` 和 `max_kv_len` | +| `seq_lens_sum` | Python `int` | 所有 request KV 长度总和 | 常用于辅助构造 KV metadata | +| `req_pool_indices` | `[bs]` | request 在 `req_to_token` 里的行号 | 用来从映射表里取物理 KV slot | + +### 4.2 Query 侧 metadata + +| 字段 | 常见 shape | 物理意义 | 简单 prefill 下的作用 | +|------|------------|----------|-----------------------| +| `qo_indptr` | `[bs + 1]` | flatten 后每个 request 的 query 段边界 | 告诉 kernel 哪些 query 属于哪个 request | +| `max_q_len` | Python `int` | batch 内单 request 最大 query 长度 | kernel 的长度上限参数 | +| `total_tokens` | Python `int` | flatten 后 query token 总数 | 决定很多输入/输出第一维 | + +### 4.3 KV 侧 metadata + +| 字段 | 常见 shape | 物理意义 | 简单 prefill 下的作用 | +|------|------------|----------|-----------------------| +| `kv_indptr` | `[bs + 1]` | flatten 后每个 request 的 KV 段边界 | 告诉 kernel 每段 KV 从哪里到哪里 | +| `kv_indices` | `[sum(seq_lens)]` 或相近长度 | flatten 后每个 KV token 对应的物理 slot | 真正告诉 kernel 去读哪些物理 KV | +| `max_kv_len` | Python `int` | batch 内单 request 最大 KV 长度 | kernel 的 KV 长度上限参数 | +| `kv_last_page_len` | `[bs]` | 每个 request 最后一页有效 token 数 | paged MLA kernel 常用 | +| `kv_lens` | `[bs]` | 每个 request 当前 KV 长度 | 在 page-table 表达里常用 | +| `page_table` | `[bs, max_pages]` | request 到 page id 的二维映射 | 非 MLA / page-table 风格 backend 常用 | + + +## 5. 这些字段的物理意义,最简单地怎么记 + +### 5.1 `qo_indptr` + +记法: + +- 它是 query 侧的 CSR 前缀和边界表 + +典型 shape: + +- `[bs + 1]` + +dtype: + +- 通常是 `int32` + +含义: + +- 第 `i` 个 request 的 query 在 flatten Q 中的范围是: + - `[qo_indptr[i], qo_indptr[i+1])` + +和哪些量对应: + +- `qo_indptr[0]` 固定是 `0` +- `qo_indptr[-1]` 通常等于: + - `total_tokens` +- `qo_indptr[i + 1] - qo_indptr[i]` 等于: + - 第 `i` 个 request 的 query 长度 + +### 5.2 `kv_indptr` + +记法: + +- 它是 KV 侧的 CSR 前缀和边界表 + +典型 shape: + +- `[bs + 1]` + +dtype: + +- 通常是 `int32` + +含义: + +- 第 `i` 个 request 的 KV 在 flatten `kv_indices` 中的范围是: + - `[kv_indptr[i], kv_indptr[i+1])` + +和哪些量对应: + +- `kv_indptr[0]` 固定是 `0` +- `kv_indptr[-1]` 通常等于: + - `len(kv_indices)` +- `kv_indptr[i + 1] - kv_indptr[i]` 等于: + - 第 `i` 个 request 当前参与 attention 的 KV 长度 + +### 5.3 `kv_indices` + +记法: + +- 它是“这次 attention 真正要访问的物理 KV slot 列表” + +典型 shape: + +- `[sum(seq_lens)]` +- 更严格一点说: + - `[kv_indptr[-1]]` + +dtype: + +- 通常是 `int32` + +含义: + +- 每个元素都是一个 physical KV slot id + +更具体一点: + +- `kv_indices` 不是“第几个 token” +- 也不是“第几个 request” +- 它是: + - **flatten 后,每个 KV token 在物理 KV cache 里的实际位置** + +它和下面几个量要一起看: + +- `req_pool_indices` + - shape 通常是 `[bs]` + - 告诉你“当前 batch 里每个 request 对应 `req_to_token` 的哪一行” +- `req_to_token` + - shape 通常是 `[req_pool_size, max_context_len]` + - 告诉你“这个 request 的逻辑第 `j` 个 token,物理上写在 KV cache 的哪个 slot” +- `seq_lens` + - shape 通常是 `[bs]` + - 告诉你“这个 request 当前有多少个 KV token 参与 attention” +- `kv_indptr` + - shape 通常是 `[bs + 1]` + - 告诉你“这个 request 对应的 KV 段,在 flatten 后 `kv_indices` 里的哪一段” + +所以可以把 `kv_indices` 理解成: + +- 先按 `req_pool_indices` 找到每个 request 在 `req_to_token` 中的那一行 +- 再按 `seq_lens[i]` 取出这行前面的有效 token 映射 +- 最后把所有 request 的映射段拼接起来 + +也就是说: + +- `kv_indptr` 负责“分段边界” +- `kv_indices` 负责“段内具体有哪些 physical slot” + +### 5.3.1 它为什么重要 + +attention kernel 真正关心的不是: + +- “这是第几个逻辑 token” + +而是: + +- “要去 KV cache 的哪个物理位置读 K/V” + +`kv_indices` 正是在回答这个问题。 + +如果没有 `kv_indices`,kernel 只知道: + +- batch 里有几个 request +- 每个 request 长度是多少 + +但它仍然不知道: + +- 这些 request 的历史 token 到底落在 KV cache 里的哪些 physical slot 上 + +### 5.3.2 它为什么通常是 flatten 的 + +`kv_indices` 做成一维 flatten 形式,而不是二维 `[bs, max_kv_len]`,是因为: + +- 不同 request 的 KV 长度不一样 +- ragged attention 更自然的表示法就是: + - 一条长数组 + - 再配一个 `kv_indptr` + +这和 CSR 稀疏矩阵的表达方式很像: + +- `kv_indices` = 数据主体 +- `kv_indptr` = 每段边界 + +### 5.3.3 它和 `total_tokens` / `max_kv_len` 的区别 + +这几个量很容易混: + +- `total_tokens` + - shape 是标量 / Python `int` + - query 侧总 token 数 +- `max_kv_len` + - shape 是标量 / Python `int` + - 单 request 最大 KV 长度 +- `kv_indices` + - shape 是一维张量 `[sum(seq_lens)]` + - 这轮 attention 真正要访问的所有 physical KV slot 列表 + +它们不是一回事。 + +例如: + +- `bs = 2` +- `seq_lens = [3, 2]` + +那么: + +- `max_kv_len = 3` +- `len(kv_indices) = 5` + +前者是“最大段长度”,后者是“所有段拼起来后的总长度”。 + +### 5.3.4 一个更完整的手算例子 + +假设: + +- `req_pool_indices = [7, 9]` + - shape: `[2]` +- `seq_lens = [5, 3]` + - shape: `[2]` +- `req_to_token[7, 0:5] = [100, 101, 102, 103, 120]` +- `req_to_token[9, 0:3] = [200, 201, 220]` + +那么先算边界: + +```text +kv_indptr = [0, 5, 8] +``` + +它的 shape 是: + +- `[3]`,也就是 `[bs + 1]` + +再按每个 request 的有效长度取映射: + +- request 0 取: + - `[100, 101, 102, 103, 120]` +- request 1 取: + - `[200, 201, 220]` + +最后拼接得到: + +```text +kv_indices = [100, 101, 102, 103, 120, 200, 201, 220] +``` + +它的 shape 是: + +- `[8]` +- 也就是: + - `[sum(seq_lens)] = [5 + 3]` + +于是: + +- request 0 的 KV 段是: + - `kv_indices[kv_indptr[0]:kv_indptr[1]]` + - 也就是 `kv_indices[0:5]` +- request 1 的 KV 段是: + - `kv_indices[kv_indptr[1]:kv_indptr[2]]` + - 也就是 `kv_indices[5:8]` + +### 5.3.5 debug 时怎么看 `kv_indices` + +如果你在 debug attention metadata,`kv_indices` 最值得看两件事: + +1. 长度对不对 + +- 在最简单 prefill 里,通常应该有: + - `len(kv_indices) == sum(seq_lens)` +- 也可以写成: + - `kv_indices.shape == (int(seq_lens.sum()),)` + +如果这个关系都不对,说明: + +- `kv_indptr` +- `seq_lens` +- 或 `req_to_token` 的使用 + +有地方没对齐。 + +2. 分段内容对不对 + +给定: + +- `kv_indptr` + - shape: `[bs + 1]` +- `kv_indices` + - shape: `[sum(seq_lens)]` + +你应该能把每个 request 对应的 physical slot 段切出来,并和: + +- `req_to_token[row, :seq_len]` + +一一对应上。 + +如果切出来的段和 `req_to_token` 不对应,常见意味着: + +- `req_pool_indices` 行号不对 +- `seq_lens` 不是这轮应看的 KV 长度 +- 或者 graph replay 时把“整块静态 buffer”错当成了“当前 bucket 视图” + +### 5.4 `max_q_len` + +记法: + +- 这轮 batch 中,单 request 最长的 query 长度 + +shape: + +- 标量 / Python `int` + +它不是: + +- 所有 query 总数 + +### 5.5 `max_kv_len` + +记法: + +- 这轮 batch 中,单 request 最长的 KV 长度 + +shape: + +- 标量 / Python `int` + + +## 6. 最简单 prefill 的三个小例子 + +### 例子 1:单 request prefill + +假设: + +- `bs = 1` +- `extend_seq_lens = [5]` + - shape: `[1]` +- `seq_lens = [5]` + - shape: `[1]` + +那么: + +- `total_tokens = 5` +- `qo_indptr = [0, 5]` + - shape: `[2]` +- `kv_indptr = [0, 5]` + - shape: `[2]` +- `max_q_len = 5` +- `max_kv_len = 5` + +如果 `req_to_token[row]` 对应的是: + +- `[100, 101, 102, 103, 104]` + +那么: + +- `kv_indices = [100, 101, 102, 103, 104]` + - shape: `[5]` + +这个例子很简单,但也正好说明: + +- 这轮 graph 里主干很多 tensor 第一维都是 `5` + +如果下一个 batch 变成 `7` 个 token: + +- 图里的很多 shape 就都要变 + +### 例子 2:两个 request,长度不同 + +假设: + +- `bs = 2` +- `extend_seq_lens = [3, 2]` + - shape: `[2]` +- `seq_lens = [3, 2]` + - shape: `[2]` + +那么: + +- `total_tokens = 5` +- `qo_indptr = [0, 3, 5]` + - shape: `[3]` +- `kv_indptr = [0, 3, 5]` + - shape: `[3]` +- `max_q_len = 3` +- `max_kv_len = 3` + +如果: + +- request 0 的物理 slot 是 `[10, 11, 12]` +- request 1 的物理 slot 是 `[20, 21]` + +那么: + +- `kv_indices = [10, 11, 12, 20, 21]` + - shape: `[5]` + +这里最值得注意的是: + +- `total_tokens = 5` +- 但 request 分段结构已经不是均匀的 + +### 例子 2.1:把 `qo_indptr + kv_indptr + kv_indices` 放在一起看 + +继续沿用上面的 batch: + +- `bs = 2` +- `extend_seq_lens = [3, 2]` +- `seq_lens = [3, 2]` +- `qo_indptr = [0, 3, 5]` +- `kv_indptr = [0, 3, 5]` +- `kv_indices = [10, 11, 12, 20, 21]` + +如果把 flatten 后的 Q token 记成: + +```text +Q_flat = [q0, q1, q2, q3, q4] +``` + +那么 query 侧分段是: + +- request 0: + - `Q_flat[0:3]` + - 也就是 `q0, q1, q2` +- request 1: + - `Q_flat[3:5]` + - 也就是 `q3, q4` + +因为: + +```text +qo_indptr = [0, 3, 5] +``` + +同样,KV 侧分段是: + +- request 0: + - `kv_indices[0:3]` + - 也就是 `[10, 11, 12]` +- request 1: + - `kv_indices[3:5]` + - 也就是 `[20, 21]` + +因为: + +```text +kv_indptr = [0, 3, 5] +kv_indices = [10, 11, 12, 20, 21] +``` + +把它们并排看,就是: + +```text +request 0: + Q range = [qo_indptr[0], qo_indptr[1]) = [0, 3) + Q tokens = [q0, q1, q2] + KV range = [kv_indptr[0], kv_indptr[1]) = [0, 3) + KV slots = [10, 11, 12] + +request 1: + Q range = [qo_indptr[1], qo_indptr[2]) = [3, 5) + Q tokens = [q3, q4] + KV range = [kv_indptr[1], kv_indptr[2]) = [3, 5) + KV slots = [20, 21] +``` + +这就是 ragged attention metadata 最核心的意思: + +- `qo_indptr` + - 告诉 kernel:flatten 后哪些 query 属于哪个 request +- `kv_indptr` + - 告诉 kernel:flatten 后哪些 KV 段属于哪个 request +- `kv_indices` + - 告诉 kernel:这个 request 的 KV 段具体对应哪些 physical KV slot + +如果再把 `req_to_token` 写出来: + +```text +req_to_token[row_of_req0, 0:3] = [10, 11, 12] +req_to_token[row_of_req1, 0:2] = [20, 21] +``` + +那就能看到: + +- `kv_indices` + 本质上就是把每个 request 在 `req_to_token` 里的有效前缀切出来,再按 request 顺序拼起来。 + +### 例子 3:`total_tokens` 一样,但 graph 仍然难复用 + +看两个 batch: + +#### batch A + +- `bs = 2` +- `extend_seq_lens = [3, 2]` + - shape: `[2]` + +得到: + +- `total_tokens = 5` +- `qo_indptr = [0, 3, 5]` + - shape: `[3]` +- `max_q_len = 3` + +#### batch B + +- `bs = 2` +- `extend_seq_lens = [4, 1]` + - shape: `[2]` + +得到: + +- `total_tokens = 5` +- `qo_indptr = [0, 4, 5]` + - shape: `[3]` +- `max_q_len = 4` + +这两个 batch: + +- `bs` 相同 +- `total_tokens` 相同 + +但: + +- `qo_indptr` 不同 +- `max_q_len` 不同 + +这说明: + +- 即使总 token 数没变 +- prefill 的“问题几何结构”仍然变了 + +这就是 graph 复用困难的关键例子。 + +### 例子 4:为什么 decode 更容易 graph + +假设 decode: + +- `bs = 2` +- 每个 request 本轮只解 1 个 token + +那么: + +- `total_tokens = 2` +- `qo_indptr = [0, 1, 2]` + - shape: `[3]` +- `max_q_len = 1` + +下一个 batch 只要 bucket 还是这个 `bs`,即使: + +- `kv_indices` +- `seq_lens` +- `kv_indptr` + +的内容变了,graph 里主干 shape 往往还是稳定得多。 + +所以: + +- decode 中 metadata 更像“数据表” +- prefill 中 metadata 更像“几何结构描述” + + +## 7. 如果硬要对最简单 prefill 做 graph,需要什么条件 + +最少需要做下面几件事中的一些: + +### 7.1 固定 `bs` + +最基础的 bucket 化: + +- 只允许某几个 `bs` 值 + +但仅固定 `bs` 还不够。 + +### 7.2 固定 `total_tokens` + +因为很多输入/输出 tensor 的第一维是: + +- `total_tokens` + +若它不固定,graph 很难复用。 + +### 7.3 固定 `max_q_len / max_kv_len` + +因为它们常常影响: + +- kernel launch 形态 +- workspace 大小 + +### 7.4 固定 workspace 形状 + +也就是说: + +- 需要让中间临时张量有固定上限 +- 或者直接预分配到某个 bucket 上限 + +### 7.5 允许 padding / pack / unpack + +最现实的手段通常是: + +- graph 外把 ragged batch 归一化 +- graph 内只处理固定形状张量 +- graph 后再 unpad + +但代价是: + +- 额外数据搬运 +- padding 带来的无效计算 + + +## 8. 为什么这比 decode 难很多 + +可以用一句最简单的话来对比: + +- `decode` 的不确定性主要是“数据值不同” +- `prefill` 的不确定性主要是“问题结构不同” + +decode 常常可以做到: + +- 固定 `num_tokens_per_bs = 1` +- 固定 `max_q_len = 1` +- 只靠 `bs bucket` 就稳定大部分形状 + +而最简单 prefill 仍然会遇到: + +- `total_tokens` 变化 +- `qo_indptr` 变化 +- `max_q_len` 变化 +- `max_kv_len` 变化 +- 中间 workspace 变化 + + +## 9. 最后总结 + +只记下面六句话就够了: + +1. 最简单 prefill 也不是固定 shape 问题,而是 ragged / varlen 问题。 +2. `total_tokens = sum(extend_seq_lens)`,它决定了很多主干张量的第一维。 +3. `qo_indptr` 和 `kv_indptr` 不是装饰字段,而是在描述这轮 attention 的分段几何结构。 +4. `max_q_len / max_kv_len` 会随着 batch 分布变化,常常进一步影响 kernel 和 workspace。 +5. 即使不考虑不同 kernel path,prefill 仍然可能因为 shape 和 workspace 不稳定而难以复用同一张 graph。 +6. 如果真的想 graph 化最简单 prefill,通常还需要 bucket 化、padding 或 pack/unpack 来先把 ragged 问题归一化。 diff --git a/work_log/MTP/MTP-2026-04-08.md b/work_log/MTP/MTP-2026-04-08.md new file mode 100644 index 000000000..ad368de3a --- /dev/null +++ b/work_log/MTP/MTP-2026-04-08.md @@ -0,0 +1,525 @@ +# 2026-04-08 MTP 调研与调试记录 + +## 目标 + +本次工作的目标是调研并尝试推进 `ATOM + SGLang plugin` 路径下的 +DeepSeek MTP 接入,重点回答下面几个问题: + +- `ATOM/atom/plugin/sglang` 当前到底支持了什么 +- upstream SGLang 的 DeepSeek MTP / NextN 是怎么组织的 +- 当前启动命令实际跑起来时,target model 和 draft model 分别是谁 +- 当前失败点落在什么地方,根因是什么 +- 如果后续正式推进,推荐的技术路线是什么 + + +## 本次结论速览 + +- `ATOM sglang plugin` 当前并没有真正把 `ATOM/atom/models/deepseek_mtp.py` + 接到 draft/MTP 路径上。 +- 当前运行形态更接近: + - target model 走 `ATOM plugin wrapper + ATOM DeepseekV3ForCausalLM` + - draft model 走 upstream SGLang 的 `DeepseekV3ForCausalLMNextN` +- 换句话说,当前不是“ATOM MTP 已经接通”,而是: + - `ATOM target + SGLang NextN draft + ATOM target verify backend` +- 本次已经修掉了第一个显式接口不兼容问题: + - upstream speculative worker 需要 target model 提供 + `get_embed_and_head()/set_embed_and_head()/set_embed()` + - `ATOM plugin wrapper` 原本没有这些接口 +- 当前新的阻塞点在: + - `TARGET_VERIFY` 路径进入 `ATOM` 的 + `sgl_attn_backend.py` + - ATOM plugin 把 verify 当成普通 extend 处理 + - 于是错误访问了 `forward_batch.extend_seq_lens` + - 但在 verify 路径下这个字段本来就可能为 `None` + + +## 背景知识 + +### 1. upstream SGLang 的 DeepSeek MTP / NextN 组织方式 + +在 upstream SGLang 里,DeepSeek 的 draft/MTP 不是通过一个独立的 +`DeepSeekMTP` 类来暴露给 speculative runtime,而是通过一个 +SGLang 风格的 draft model 壳: + +- `sglang/python/sglang/srt/models/deepseek_nextn.py` +- 类名:`DeepseekV3ForCausalLMNextN` + +这层壳的特点: + +- 对外长得像标准的 `ForCausalLM` +- 能直接被 `ModelRegistry` 解析和实例化 +- 带有 `load_weights(..., is_nextn=True)` +- 带有 `get_embed_and_head()/set_embed_and_head()` +- 能直接对接 SGLang speculative worker + +它内部并不会真的再构一个完整的 target DeepSeek 模型,而是构一个 +更薄的 NextN draft 结构。 + + +### 2. ATOM 的 MTP 组织方式 + +ATOM 侧则是另一种设计: + +- `ATOM/atom/models/deepseek_mtp.py` +- 类名:`DeepSeekMTP` + +这更像一个 draft core,而不是一个完整的 SGLang 风格 runtime wrapper。 +它暴露的是: + +- `forward(input_ids, positions, hidden_states, ...)` +- `compute_logits(hidden_states, spec_step_idx=...)` + +也就是说: + +- upstream SGLang:偏“运行时壳子” +- ATOM:偏“底层 draft 模型核心” + + +### 3. 为什么这是个关键差异 + +这意味着 plugin 端后续有两种思路: + +1. 继续沿用 upstream 的思路,在 plugin 里做一个 + `DeepseekV3ForCausalLMNextN` 风格的壳 +2. 尽量复用 `ATOM/atom/models/deepseek_mtp.py`,只补一层很薄的 + SGLang 兼容 wrapper + +本次调研后的倾向是: + +- 不建议在 plugin 里再复制一整套 upstream NextN 继承链 +- 更推荐“上层保留 SGLang 兼容接口,下层复用 ATOM DeepSeekMTP” + + +### 4. speculative 运行时里的几个对象容易混淆 + +在 SGLang speculative 模式下,scheduler 里会出现多个 worker: + +- `tp_worker` + - target `TpModelWorker` +- `draft_worker` + - 变量名容易误导 + - 在 scheduler 里它其实通常是 speculative orchestrator + - 例如 `EAGLEWorker` / `EAGLEWorkerV2` +- 真正的 draft `TpModelWorker` + - 在 orchestrator 内部 + +所以: + +- `self.model_worker = self.draft_worker` + +并不是“target worker 被 draft worker 替代”,而是: + +- scheduler 把统一执行入口切到了 speculative orchestrator +- orchestrator 再内部协调: + - draft propose + - target verify + - draft extend + + +### 5. `embed_and_head` 是什么,为什么 drafter 需要 + +upstream speculative worker 在初始化 draft model 时,会从 target model 取: + +- `embed = embed_tokens.weight` +- `head = lm_head.weight` + +原因: + +- drafter 需要把 token id 变成 embedding,再继续往下算 +- drafter 也需要把 hidden state 变成 logits,提议下一个 token +- 共享 target 的 embedding / lm_head 可以: + - 节省显存 + - 保持 vocab 完全一致 + - 避免 draft 再重复加载一份大权重 + + +## 当前代码状态理解 + +### 1. ATOM plugin 当前只导出了哪些 model + +文件: + +- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` +- `ATOM/atom/plugin/register.py` + +当前 external package 暴露的 `_MODEL_NAMES` 只有: + +- `DeepseekV3ForCausalLM` +- `Qwen3MoeForCausalLM` + +ATOM plugin 支持的 `_ATOM_SUPPORTED_MODELS` 里也没有: + +- `DeepseekV3ForCausalLMNextN` +- `DeepSeekMTPModel` + +这意味着: + +- target `DeepseekV3ForCausalLM` 可以被 ATOM external package 覆盖 +- draft `DeepseekV3ForCausalLMNextN` 不会被 ATOM external package 覆盖 + + +### 2. 为什么 target 在 `prepare_model()` 里还是 `DeepseekV3ForCausalLM` + +文件: + +- `ATOM/atom/plugin/prepare.py` +- `sglang/python/sglang/srt/configs/model_config.py` + +我们在 `prepare_model()` 打日志时看到: + +- `model_arch in prepare_model: DeepseekV3ForCausalLM` + +这并不矛盾,因为那个 `prepare_model()` 调用发生在 target 路径。 + +而 draft 路径是另外一条 worker 初始化链,且在 `ModelConfig._config_draft_model()` +里会把: + +- `DeepseekV3ForCausalLM` + -> `DeepseekV3ForCausalLMNextN` + +所以: + +- target 看到 `DeepseekV3ForCausalLM` 正常 +- draft 并不经过同一个 `prepare_model()` 观察点 + + +### 3. 当前实际 load 的 draft module 是谁 + +从本次运行已经走到 speculative verify 阶段可以判断: + +- draft worker 已经成功创建 +- draft model 已经成功 load +- draft model 不是完全没起来 + +结合当前注册关系,最合理的判断是: + +- target model:ATOM `DeepseekV3ForCausalLM` +- draft model:upstream SGLang `DeepseekV3ForCausalLMNextN` + +不是: + +- ATOM `DeepSeekMTP` + + +## 本次实验与过程记录 + +### 实验 1:总体调研 ATOM sglang plugin 中 MTP 的现状 + +动机: + +- 先搞清楚 plugin 里到底有哪些 speculative / MLA / MTP 相关代码 +- 避免一上来就在错误层面改代码 + +主要阅读文件: + +- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` +- `ATOM/atom/plugin/sglang/attention_backend/sgl_attention_mla.py` +- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` +- `ATOM/atom/plugin/register.py` +- `ATOM/atom/plugin/prepare.py` +- `ATOM/atom/models/deepseek_mtp.py` +- `ATOM/atom/spec_decode/eagle.py` +- `sglang/python/sglang/srt/models/deepseek_nextn.py` +- `sglang/python/sglang/srt/layers/attention/aiter_backend.py` + +结果: + +- plugin 侧已经部分支持 speculative-aware 的 MLA 计算路径 +- 但没有完整接入 draft/MTP model +- 真正完整的 `DeepSeekMTP` 在 ATOM 原生链路里,不在当前 plugin draft 路径里 + + +### 实验 2:确认 target / draft 两条模型构造链 + +动机: + +- 理清 scheduler 里为什么既有 target worker 又有 draft worker +- 理清为什么 target model 和 draft model 不一定走同一套 model class + +关键结论: + +- speculative 模式下,运行时确实同时存在 target model 和 draft model +- 但二者最终还是都走通用 loader / `_initialize_model()` +- 区别在于: + - target 的 `ModelConfig` 正常走原始架构 + - draft 的 `ModelConfig` 会先做 `is_draft_model=True` 的架构改写 + + +### 实验 3:首个阻塞点 - `get_embed_and_head` 缺失 + +报错: + +- `AttributeError: 'DeepseekV3ForCausalLM' object has no attribute 'get_embed_and_head'` + +发生点: + +- `sglang/python/sglang/srt/speculative/eagle_worker.py` + +动机: + +- 需要确认这是 draft model 没构出来,还是 target/draft 接口不兼容 + +分析结果: + +- 不是 draft 没构出来 +- upstream speculative worker 在初始化 draft 时,要从 target model 取 + embedding / lm_head +- 但 `ATOM plugin wrapper` 没有把这几个接口暴露给外层 + + +### 实验 4:修复 wrapper 与 upstream speculative worker 的接口契约 + +改动文件: + +- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` + +改动动机: + +- 让 target wrapper 满足 upstream speculative worker 期待的最小接口 + +最终保留的最小接口: + +- `get_embed_and_head()` +- `set_embed_and_head()` +- `set_embed()` + +说明: + +- 一开始尝试把 `get_embed_and_head` 打到 inner `self.model` 上 +- 但 upstream 调的是外层 wrapper 对象 +- 所以最终改成正式的 wrapper 成员方法 + +结果: + +- `get_embed_and_head` 的报错被消除 +- 程序继续向前推进到了 speculative verify 阶段 + + +### 实验 5:新的阻塞点 - verify 路径 metadata 初始化错误 + +最新报错: + +- `AttributeError: 'NoneType' object has no attribute 'max'` + +调用栈末端: + +- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` +- `_init_extend_mla()` +- `forward_batch.extend_seq_lens.max().item()` + +初看怀疑: + +- 是不是 `batch` 为 `None` + +进一步定位后确认: + +- 不是 `batch` 为 `None` +- 也不是 `forward_batch` 为 `None` +- 真正为 `None` 的是: + - `forward_batch.extend_seq_lens` + +而且这在 `TARGET_VERIFY` 路径下是正常现象。 + + +## 为什么 `extend_seq_lens` 在 verify 里是 `None` + +文件: + +- `sglang/python/sglang/srt/speculative/eagle_info.py` +- `sglang/python/sglang/srt/managers/schedule_batch.py` + +`prepare_for_verify()` 会做: + +- 改 `batch.input_ids` +- 分配 `out_cache_loc` +- 更新 `req_to_token_pool` + +但不会去填普通 extend 用的 `extend_lens/extend_seq_lens`。 + +之后 `ScheduleBatch.get_model_worker_batch()` 会把 `self.extend_lens` +透传为 `extend_seq_lens`。 + +因此在 verify 路径下: + +- `extend_seq_lens=None` + +是完全可能且合理的。 + + +## 为什么 upstream 不会在这里崩 + +文件: + +- `sglang/python/sglang/srt/layers/attention/aiter_backend.py` + +upstream 的 `AiterAttnBackend.init_forward_metadata()` 不是简单分成 +decode 和 extend 两大类,而是专门区分: + +- `decode_or_idle` +- `draft_extend` +- `target_verify` +- 普通 extend + +其中 `target_verify` 分支会自己根据: + +- `spec_info.draft_token_num` +- `forward_batch.seq_lens` + +来构造: + +- `qo_indptr` +- `kv_indptr` +- `kv_indices` + +它根本不依赖 `forward_batch.extend_seq_lens`。 + + +## 为什么 ATOM plugin 会在这里崩 + +文件: + +- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` + +当前 ATOM plugin 这层把 metadata 初始化逻辑写成了: + +- `decode_or_idle` -> `_init_forward_metadata_decode()` +- 其他全部 -> `_init_forward_metadata_extend()` + +由于在 SGLang 里: + +- `ForwardMode.TARGET_VERIFY` 也被算作 `is_extend()` + +所以 verify 路径被误送进了普通 MLA extend 初始化: + +- `_init_extend_mla()` + +而这个函数又直接假设: + +- `forward_batch.extend_seq_lens` 一定存在 + +于是崩溃。 + +结论: + +- 当前问题不是“verify 输入准备坏了” +- 而是“ATOM plugin 缺 upstream 那段专门的 `TARGET_VERIFY` metadata 分支” + + +## 已做改动 + +### 文件 + +- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` + +### 实际改动 + +新增 wrapper 层方法: + +- `get_embed_and_head()` +- `set_embed_and_head()` +- `set_embed()` + +### 改动目的 + +- 对齐 upstream speculative worker 对 target/draft 共享 embedding/lm_head + 的接口期望 + +### 当前状态 + +- 这个改动已经验证有效 +- 运行不再卡在 `get_embed_and_head` 缺失 + + +## 当前未改的部分 + +本次刻意没有做这些事情: + +- 没有把 `DeepseekV3ForCausalLMNextN` 纳入 ATOM external package +- 没有给 plugin 接上 `ATOM DeepSeekMTP` +- 没有补 `TARGET_VERIFY` 的 metadata 初始化逻辑 +- 没有去动 draft model 的 attention backend + +原因: + +- 需要先把现有混合路径看清楚 +- 先分清楚是接口问题、metadata 问题,还是 draft model 架构问题 + + +## 目前推荐的后续推进顺序 + +### 第一步 + +先把 `TARGET_VERIFY` 在 `ATOM sgl_attn_backend.py` 中的 metadata 初始化补齐。 + +具体方向: + +- 参考 upstream + `sglang/python/sglang/srt/layers/attention/aiter_backend.py` + 的 `is_target_verify()` 分支 +- 不要再让 verify 走通用 `_init_extend_mla()` + + +### 第二步 + +验证当前混合路径是否可以完整跑完: + +- `draft -> target verify -> draft extend` + +如果这一步都跑不通,就还不适合开始切 draft 到 ATOM。 + + +### 第三步 + +在 draft 路径做架构选择: + +推荐方案: + +- 写一个 SGLang 兼容的薄 wrapper +- 内部复用 `ATOM/atom/models/deepseek_mtp.py` + +不推荐方案: + +- 在 plugin 里复制一整套新的 NextN / MTP 继承链 + + +## 关键文件索引 + +### ATOM 侧 + +- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` +- `ATOM/atom/plugin/register.py` +- `ATOM/atom/plugin/prepare.py` +- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` +- `ATOM/atom/plugin/sglang/attention_backend/sgl_attention_mla.py` +- `ATOM/atom/models/deepseek_v2.py` +- `ATOM/atom/models/deepseek_mtp.py` +- `ATOM/launch_deepseek_mtp.sh` + +### upstream SGLang 侧 + +- `sglang/python/sglang/srt/configs/model_config.py` +- `sglang/python/sglang/srt/models/deepseek_nextn.py` +- `sglang/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py` +- `sglang/python/sglang/srt/speculative/eagle_worker.py` +- `sglang/python/sglang/srt/layers/attention/aiter_backend.py` +- `sglang/python/sglang/srt/managers/scheduler.py` +- `sglang/python/sglang/srt/managers/tp_worker.py` +- `sglang/python/sglang/srt/model_executor/forward_batch_info.py` +- `sglang/python/sglang/srt/speculative/eagle_info.py` + + +## 当前会话最终状态 + +- 已明确:当前 draft 不是 ATOM `DeepSeekMTP` +- 已明确:当前 draft 更可能是 upstream `DeepseekV3ForCausalLMNextN` +- 已明确:target 是 ATOM `DeepseekV3ForCausalLM` +- 已修复:wrapper 缺 `get_embed_and_head` 等接口的问题 +- 已定位:新的核心阻塞点是 `TARGET_VERIFY` metadata 初始化不完整 + +因此,本次工作最重要的阶段性成果是: + +- 把“到底是谁在跑 MTP / NextN” +- “当前失败发生在哪一层” +- “后面应该先补哪一段逻辑” + +这三件事彻底理清了。 diff --git a/work_log/MTP/MTP-2026-04-09.md b/work_log/MTP/MTP-2026-04-09.md new file mode 100644 index 000000000..1127bfb7d --- /dev/null +++ b/work_log/MTP/MTP-2026-04-09.md @@ -0,0 +1,715 @@ +# 2026-04-09 ATOM Plugin 模式下 DeepSeek MTP 接入与 CUDAGraph 调试记录 + +## 目标 + +本次工作的目标是继续推进 `ATOM + SGLang plugin` 路径下的 DeepSeek MTP 接入, +并重点解决下面几个问题: + +- 让 `ATOM plugin` 在不修改 upstream `sglang` 的前提下,真正接管 DeepSeek draft/MTP 路径 +- 确认 `SGLang` 当前在 plugin 模式下到底是怎样解析 draft model 的 +- 把 `ATOM/atom/models/deepseek_mtp.py` 绑定到 `SGLang` 期望的 draft model 接口上 +- 修复接入过程中出现的 runtime / speculative / CUDAGraph 相关问题 +- 记录 `ATOM` 与 `SGLang` 在 MTP 抽象上的差异,方便后续继续演进 + + +## 本次结论速览 + +- 当前已经实现: + - 在 `ATOM plugin` 中新增 `DeepseekV3ForCausalLMNextN` wrapper + - 通过 `SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.models` 成功覆盖 upstream 的 draft model registry + - draft 路径已不再直接使用 upstream `sglang/srt/models/deepseek_nextn.py`,而是通过 plugin wrapper 绑定到 `ATOM DeepSeekMTP` +- 当前 draft 接线方式为: + - `SGLang` 仍然认为自己在加载一个独立的 draft model:`DeepseekV3ForCausalLMNextN` + - 但这个 draft model 的内部实现已经被 ATOM plugin 替换为 `ATOM/atom/models/deepseek_mtp.py` 里的 `DeepSeekMTP` +- 当前已经修过的几类问题: + - target wrapper 缺少 `get_embed_and_head / set_embed_and_head / set_embed` + - plugin runtime 中 target / draft 共用全局 `current_atom_config`,导致 MoE 串台 + - draft runtime layer id 使用了 checkpoint 全局层号,导致 KV cache layer index 越界 + - CUDAGraph 初始化时 `TARGET_VERIFY` / `DRAFT_EXTEND` 缺少 metadata 分支 + - `sglang` 的 `RadixAttention` 默认 `k_scale/v_scale` 在 CPU,plugin wrapper 需要显式搬到 CUDA +- 当前仍未完全解决的问题集中在: + - `CUDAGraph + TARGET_VERIFY + MLA decode` 路径下,某些传给 `aiter.mla_decode_stage1_asm_fwd` 的 metadata tensor 仍然落在 CPU + - 当前已明确抓到的一个具体问题是:`kv_last_page_lens(device=cpu, ...)` +- 一个很重要的最新判断: + - 从代码看,`ATOMAttnBackendForSgl` 对 `init_cuda_graph_state()` 的 override 本身是成功的 + - 更可能的问题不是 override 语义失败,而是后续 graph metadata 组装阶段把 `forward_metadata.kv_last_page_len` 绑定成了 CPU tensor + + +## 背景知识 + +### 1. upstream SGLang 的 DeepSeek MTP 组织方式 + +upstream `SGLang` 对 DeepSeek MTP 的处理方式是: + +- draft model 会被改写成一个独立的 model architecture +- 架构名是 `DeepseekV3ForCausalLMNextN` +- 实现文件是: + - `sglang/python/sglang/srt/models/deepseek_nextn.py` + +这意味着在 `SGLang` 看来,DeepSeek MTP / NextN 不是 target model 内部的一段辅助逻辑, +而是一个独立的 draft model 类,具备: + +- 自己的 `EntryClass` +- 自己的 `load_weights(..., is_nextn=True)` +- 自己的 `forward(...)` +- 和 speculative worker 的 embed/head 共享接口 + + +### 2. ATOM 的 DeepSeek MTP 组织方式 + +`ATOM` 里对应的实现是: + +- 文件: + - `ATOM/atom/models/deepseek_mtp.py` +- 类: + - `DeepSeekMTP` + +它更像一个 draft core,而不是一个 SGLang 风格的完整 runtime model 壳子。 + +它暴露的主要接口是: + +- `forward(input_ids, positions, hidden_states, ...)` +- `compute_logits(hidden_states, spec_step_idx=...)` + +也就是说: + +- upstream `SGLang`:偏“独立 draft model 壳子” +- `ATOM`:偏“独立 draft 计算模块” + + +### 3. 为什么 plugin 里需要 wrapper + +因为 `SGLang` 期望加载的是: + +- `DeepseekV3ForCausalLMNextN` + +而 `ATOM` 现成提供的是: + +- `DeepSeekMTP` + +所以 plugin 需要补一层 very thin wrapper,把: + +- `SGLang` 的 draft model 接口 + +映射到: + +- `ATOM DeepSeekMTP` + +这也是本次新增 `deepseek_nextn_wrapper.py` 的根本原因。 + + +### 4. `EntryClass` 在 SGLang external model package 中的作用 + +`SGLang` 的 external model package 机制不是靠显式调用 register API 完成的, +而是约定: + +- 遍历 `SGLANG_EXTERNAL_MODEL_PACKAGE` +- import 这个包下面的所有 module +- 读取每个 module 的 `EntryClass` +- 用 `EntryClass.__name__` 作为 architecture 名称注册进 `ModelRegistry` + +因此,只要: + +- 文件位于 `atom.plugin.sglang.models` +- module import 成功 +- 其中声明了: + - `EntryClass = [DeepseekV3ForCausalLMNextN]` + +那么 `SGLang` 就会用这个类覆盖 upstream 同名 architecture。 + + +### 5. `NEXTN`、`EAGLE`、`EAGLEWorker`、`EAGLEWorkerV2` 的关系 + +这一点在调试中非常关键。 + +当前启动脚本里使用的是: + +- `--speculative-algorithm NEXTN` + +但在 `SGLang` 中,这个参数会被进一步改写为: + +- `EAGLE` + +而最终选择哪个 worker,要看: + +- 是否开启 spec v2 / overlap schedule + +当前这次实验里没有开启: + +- `SGLANG_ENABLE_SPEC_V2=True` + +因此当前实际使用的是: + +- `sglang/python/sglang/srt/speculative/eagle_worker.py` + +而不是: + +- `eagle_worker_v2.py` + + +## 本次主要代码改动 + +### 1. 新增 draft wrapper + +文件: + +- `ATOM/atom/plugin/sglang/models/deepseek_nextn_wrapper.py` + +新增类: + +- `DeepseekV3ForCausalLMNextN` + +目的: + +- 让 `SGLang` 在解析 draft architecture 时命中 plugin 自己的 wrapper +- 在不修改 upstream `sglang` 的前提下,把 draft model 内部实现切到 `ATOM DeepSeekMTP` + +该 wrapper 目前承担的职责包括: + +- 生成 plugin 模式下的 `atom_config` +- 将 config 改写为 `deepseek_mtp` / `DeepSeekMTPModel` 语义 +- 实例化 `ATOM/atom/models/deepseek_mtp.py::DeepSeekMTP` +- 调 `setup_deepseek_for_sglang()` 做 DeepSeek MLA patch +- 暴露: + - `get_embed_and_head()` + - `set_embed_and_head()` + - `set_embed()` +- `forward()` 中消费: + - `forward_batch.spec_info.hidden_states` +- `load_weights()` 时走: + - `load_model(..., spec_decode=True)` + + +### 2. plugin runtime scope 收口 + +文件: + +- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` + +新增 helper: + +- `plugin_runtime_scope(...)` + +目的: + +- 不修改 `ATOM` 非 plugin 目录下全局配置实现 +- 但在 plugin 层控制: + - 当前 framework + - 当前 atom_config + +动机: + +- target wrapper 和 draft wrapper 同时存在时,共用 `ATOM` 全局 runtime state +- 会导致: + - `current_atom_config` 串台 + - MoE 静态上下文读错实例 + +实际效果: + +- target/draft 的 `__init__ / forward / load_weights` 都在 plugin scope 中运行 +- 避免 draft 初始化后把 target runtime 全局状态永久污染 + + +### 3. target wrapper 中补齐 config 绑定 + +文件: + +- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` + +修改点: + +- 在 `atom.prepare_model(...)` 返回后,立即抓取当前 `atom_config` +- 如果 `self.model.atom_config` 不存在,则显式补上 + +动机: + +- 避免 `setup_deepseek_for_sglang()` 回退去读全局 `get_current_atom_config()` +- 在 runtime scope 退出后出现: + - `AssertionError: Current atom config is not set` + + +### 4. draft runtime layer id 重编号 + +文件: + +- `ATOM/atom/plugin/sglang/models/deepseek_nextn_wrapper.py` + +新增逻辑: + +- `_retag_mtp_runtime_layer_ids(self.model)` + +动机: + +- `ATOM DeepSeekMTP` 的 checkpoint 语义使用全局层号: + - 如 `61`, `62`, ... +- 但 `SGLang` draft worker 给 draft KV cache 分配的 layer index 是本地层号: + - `0..num_nextn_layers-1` + +此前问题: + +- runtime `layer_id` 使用了 checkpoint/global layer id +- `token_to_kv_pool.set_kv_buffer(...)` 用这个 id 访问 draft KV buffer +- 出现: + - `IndexError: list index out of range` + +修法: + +- 保留 prefix / weight name 的全局层号语义 +- 仅把 runtime attention / radix attention / nested attn 的 `layer_id / layer_num` + 改为 draft-local layer index + + +### 5. 恢复 `config.json` 中 `num_hidden_layers` + +文件: + +- `ATOM/deepseek-ai/DeepSeek-R1-0528/config.json` + +改动: + +- 临时实验中曾将: + - `num_hidden_layers: 61 -> 16` +- 后来已恢复: + - `16 -> 61` + +结论: + +- 这个参数不能当成“简化实验”的随意开关 +- 它不仅影响 model topology,也影响: + - MTP layer weight naming + - runtime/global layer numbering + - KV cache / draft wrapper 语义 + + +### 6. plugin `RadixAttention` 中强制把 `k_scale / v_scale` 放到 CUDA + +文件: + +- `ATOM/atom/plugin/sglang/attention_backend/radix_attention.py` + +问题: + +- upstream `sglang` 的 `RadixAttention` 默认会把: + - `k_scale` + - `v_scale` + 建在 CPU 上 +- plugin wrapper 之前只在它们为 `None` 时才补 CUDA 参数 +- 但实际上这两个参数“不为 None,只是在 CPU 上” + +修法: + +- `None` 时创建 CUDA 参数 +- 已存在但不在 CUDA 时,也强制 `.to("cuda")` + +动机: + +- 避免 `mla_decode_fwd` 中把 CPU scale tensor 传进 `aiter` +- 触发: + - `aiter_tensor_t only supports CUDA tensors` + + +### 7. `sgl_attn_backend.py` 中补齐 speculative CUDAGraph metadata 分支 + +文件: + +- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` + +问题: + +- plugin 侧普通 runtime 的 `init_forward_metadata(...)` 已经支持: + - `decode_or_idle` + - `draft_extend` + - `target_verify` +- 但 `CUDAGraph capture/replay` 专用的 metadata 初始化函数只支持: + - `decode_or_idle` + +因此在 graph capture 阶段遇到: + +- `ForwardMode.TARGET_VERIFY` + +会直接报: + +- `ValueError: Invalid mode: forward_mode=` + +修法: + +- 在 `init_forward_metadata_capture_cuda_graph()` +- 和 `init_forward_metadata_replay_cuda_graph()` + +中补上: + +- `TARGET_VERIFY` +- `DRAFT_EXTEND` + +对应的 metadata 初始化分支 + +注意: + +- 这次补分支时又额外发现 plugin 自己的 `ForwardMetadata` 签名和 upstream 不同 +- plugin 版本额外有两个必填位置参数: + - `page_table` + - `kv_lens` +- 因此需要在新补的分支中显式补 `None, None` + + +### 8. CUDAGraph 相关深度 debug instrumentation + +文件: + +- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` +- `aiter/aiter/mla.py` + +目的: + +- 当前问题只在 `CUDAGraph + TARGET_VERIFY + MLA decode` 路径上复现 +- 普通推理路径可以正常工作 +- 因此需要抓到: + - 到底是哪一个 tensor 在进入 `aiter` kernel 前还停留在 CPU + +具体做法: + +1. 在 plugin backend `_call_mla_decode_fwd()` 里增加 tensor state dump +2. 后续发现这层不足以定位内部派生参数,于是继续下沉到: + - `aiter/aiter/mla.py::mla_decode_fwd` +3. 在真正调用: + - `aiter.mla_decode_stage1_asm_fwd(...)` + 前,检查: + - `q` + - `kv_buffer` + - `qo_indptr` + - `kv_indptr` + - `kv_indices` + - `kv_last_page_lens` + - `num_kv_splits_indptr` + - `work_meta_data` + - `work_indptr` + - `work_info_set` + - `q_scale` + - `kv_scale` + +最终定位到: + +- `kv_last_page_lens(device=cpu, dtype=torch.int32, shape=(48,), is_cuda=False)` + + +### 9. 增加断言区分 graph buffer 初始化与 metadata 绑定问题 + +文件: + +- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` + +新增 assert 位置: + +1. `init_cuda_graph_state()` 末尾 + - 断言: + - `self.cuda_graph_kv_last_page_len.is_cuda` +2. `init_forward_metadata_capture_cuda_graph()` / `replay` + - 在构造完 `ForwardMetadata` 后断言: + - `self.forward_metadata.kv_last_page_len is None or is_cuda` + +动机: + +- 区分问题到底是: + - graph state 初始化时就落到 CPU + - 还是后续 metadata 构造时又从别的来源拿了 CPU tensor + + +## 实验过程与关键观察 + +### 实验 1:确认 draft registry 是否已切到 ATOM wrapper + +动机: + +- 在真正调 runtime 之前,先确认 `ModelRegistry` 是否已经成功把 draft architecture 指到 plugin wrapper + +方法: + +- 通过 `SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.models` +- 检查: + - `ModelRegistry.resolve_model_cls(["DeepseekV3ForCausalLMNextN"])` + +结果: + +- 初次尝试时由于 import 链问题失败,draft 仍落回 upstream +- 修正 import 时机与 framework runtime scope 后,registry 成功解析到: + - `atom.plugin.sglang.models.deepseek_nextn_wrapper.DeepseekV3ForCausalLMNextN` + + +### 实验 2:确认 `NEXTN` 实际走的是哪个 speculative worker + +动机: + +- 判断当前问题是 `eagle_worker` 还是 `eagle_worker_v2` 路径特有 + +结果: + +- `NEXTN` 在 `SGLang` 中会先映射成 `EAGLE` +- 当前未开启 spec v2 / overlap +- 所以实际使用的是: + - `sglang/python/sglang/srt/speculative/eagle_worker.py` + +不是: + +- `eagle_worker_v2.py` + + +### 实验 3:内存初始化阶段 `mem_fraction_static` 与脚本不一致 + +现象: + +- 脚本设置: + - `--mem-fraction-static 0.9` +- runtime 报错里看到: + - `self.server_args.mem_fraction_static=0.765` + +结论: + +- 不是脚本没生效 +- 而是 `SGLang` 在 AMD + `aiter` + 长上下文模式下,会再乘一层 `0.85` + +相关逻辑: + +- 若: + - `attention_backend == "aiter"` + - 且 `context_len > 8192` +- 则: + - `mem_fraction_static *= 0.85` + + +### 实验 4:MoE 串台问题 + +现象: + +- 接上 draft wrapper 后,target 路径的 `MoE` forward 开始报: + - `KeyError: 'model.layers.3.mlp.experts'` + +结论: + +- 不是 MTP wrapper 直接把 MoE 改坏了 +- 而是: + - target / draft 共用全局 `current_atom_config` +- draft 初始化把全局配置切成了 draft config +- target 的 MoE forward 再去读全局配置时,读到了错误的 `static_forward_context` + +修法: + +- 在 plugin 层引入 `plugin_runtime_scope(...)` +- 所有 plugin wrapper 的 init / forward / load 都显式切回自己的 runtime context + + +### 实验 5:MTP runtime layer id 越界 + +现象: + +- 将 `config.json` 中 `num_hidden_layers` 改成 `16` 后,出现: + - `IndexError: list index out of range` + +分析后确认: + +- 根因不是“改成 16”本身 +- 而是 runtime `layer_id` 错误地用了 checkpoint/global layer number +- draft worker 的 KV cache 只按 draft-local 层数分配 + +结论: + +- 运行时 `layer_id` 应该是: + - `0, 1, 2, ...` +- 而不是: + - `16`, `61`, `62`, ... + + +### 实验 6:CUDAGraph `TARGET_VERIFY` 分支缺失 + +现象: + +- 开启 cuda graph 初始化时,直接报: + - `ValueError: Invalid mode: forward_mode=` + +结论: + +- plugin 侧 CUDAGraph metadata 初始化漏了 `TARGET_VERIFY` / `DRAFT_EXTEND` +- 这是 plugin 侧缺口,不是 upstream `SGLang` 自身不支持 + + +### 实验 7:`aiter_tensor_t only supports CUDA tensors` + +现象: + +- 在 `CUDAGraph + TARGET_VERIFY + MLA decode` 路径下, + `mla_decode_stage1_asm_fwd(...)` 报: + - `aiter_tensor_t only supports CUDA tensors` + +初步怀疑: + +- 可能是 plugin wrapper 的 `k_scale / v_scale` 还在 CPU + +部分修复: + +- 已在 plugin `RadixAttention` 中把 scale tensor 统一搬到 CUDA + +进一步深挖: + +- 下沉到 `aiter/aiter/mla.py` +- 发现真正触发断言的 tensor 为: + - `kv_last_page_lens(device=cpu, ...)` + + +### 实验 8:`kv_last_page_lens` 为 CPU 的进一步判断 + +关键观察: + +- plugin `ATOMAttnBackendForSgl.init_cuda_graph_state()` 中, + `self.cuda_graph_kv_last_page_len` 是按 `device=self.device` 创建的 +- 因此不太像是“子类 override 根本没生效” + +当前更强的判断是: + +- `forward_batch.attn_backend` 大概率仍然是 `ATOMAttnBackendForSgl` +- 但后续在 graph metadata 绑定 / 组装阶段,`forward_metadata.kv_last_page_len` + 被绑定成了 CPU tensor +- 也不排除: + - 某条复用父类 `forward_decode` 的路径里使用了来自父类默认初始化的 CPU graph buffer + +当前状态: + +- 已加断言区分: + - graph state 初始化阶段 + - 与 metadata 绑定阶段 +- 但在本次会话结束时,尚未拿到最终触发哪一个断言的最新日志 + + +## 当前对整体架构的理解 + +### 1. target 与 draft 在 plugin 模式下的实际形态 + +当前链路中: + +- target model: + - 由 `base_model_wrapper.py` 暴露为 `DeepseekV3ForCausalLM` + - 内部仍是 `ATOM DeepseekV3ForCausalLM` +- draft model: + - 由 `deepseek_nextn_wrapper.py` 暴露为 `DeepseekV3ForCausalLMNextN` + - 内部被绑定到 `ATOM DeepSeekMTP` + + +### 2. SGLang 与 ATOM 在 MTP 抽象上的差异 + +可以这样概括: + +- `SGLang` + - 把 MTP / NextN 视为一个独立的 runtime model + - draft worker 会单独初始化这个 model +- `ATOM` + - 把 MTP 实现成一个独立 draft core / module + - 需要由 speculative runtime 或 plugin wrapper 再包一层 + +因此: + +- `SGLang` 的差异在“接口层” +- `ATOM` 的差异在“实现层” + + +### 3. plugin 当前真正做的事情 + +当前 plugin 并不是直接修改 upstream `sglang` 的 DeepSeek MTP 逻辑,而是在三个层面做了替换: + +1. registry 层: + - 用 `EntryClass` 覆盖 `DeepseekV3ForCausalLMNextN` +2. wrapper 层: + - 用 SGLang 兼容壳把 `DeepSeekMTP` 暴露成 draft model +3. runtime 层: + - 补齐: + - embed/head sharing + - speculative hidden_states 输入 + - spec_decode 权重加载 + - runtime layer id 重编号 + - plugin runtime scope + + +## 当前仍存在的问题 + +截至本次记录结束,仍然有以下未完全解决的问题: + +- `CUDAGraph + TARGET_VERIFY + MLA decode` 路径仍存在 graph-only bug +- 当前最具体的线索是: + - `kv_last_page_lens` 在进入 `aiter.mla_decode_stage1_asm_fwd` 前为 CPU tensor +- 尚未最终确认这个 CPU tensor 是: + - 来自 graph state 初始化未正确走 plugin override + - 还是 metadata 构造过程中被其它路径重新绑定 +- 当前还没有完成的验证是: + - 让新增 assert 真正触发并给出第一现场 + + +## 本次新增 / 修改文件清单 + +### 新增 + +- `ATOM/atom/plugin/sglang/models/deepseek_nextn_wrapper.py` +- `ATOM/work_log/MTP/MTP-2026-04-09.md` + +### 修改 + +- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` +- `ATOM/atom/plugin/sglang/attention_backend/radix_attention.py` +- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` +- `aiter/aiter/mla.py` +- `ATOM/deepseek-ai/DeepSeek-R1-0528/config.json` + + +## 对后续工作的建议 + +### 1. 先把 CUDAGraph 的 `kv_last_page_len` 问题钉死 + +建议下一步只做一件事: + +- 继续跑一次带最新 assert 的启动 +- 看究竟是: + - `init_cuda_graph_state()` 断言触发 + - 还是 `ForwardMetadata` 组装后的断言触发 + +这样可以把问题准确收敛到: + +- graph buffer 初始化层 +- 或 metadata 绑定层 + + +### 2. 不要再直接改 `config.json::num_hidden_layers` + +如果需要降低实验复杂度,更建议: + +- 调小: + - `--context-length` + - `--max-running-requests` + - `--chunked-prefill-size` + - `--max-total-tokens` + +而不是直接改: + +- `num_hidden_layers` + + +### 3. 区分“普通 speculative 能跑”和“CUDAGraph 也能跑” + +当前已经能说明: + +- 普通 speculative 路径和 graph 路径不是同一个问题集合 +- graph 路径会更早暴露: + - metadata device 问题 + - verify-only 分支缺失 + - graph persistent buffer 设备不一致 + +因此后续调试时建议始终把问题分成两类: + +- 非 graph runtime bug +- CUDAGraph 专有 bug + + +## 一句话总结 + +本次工作的核心进展不是“DeepSeek MTP 已完全跑通”,而是: + +- 已经把 draft architecture 从 upstream `DeepseekV3ForCausalLMNextN` + 成功接到了 `ATOM DeepSeekMTP` +- 已经把 plugin runtime scope、MTP runtime layer id、speculative metadata 等关键适配层基本搭起来 +- 当前剩余阻塞点主要集中在 `CUDAGraph + TARGET_VERIFY + MLA decode` 的 graph-only 设备与 metadata 一致性问题 + diff --git a/work_log/MTP/MTP-2026-04-10.md b/work_log/MTP/MTP-2026-04-10.md new file mode 100644 index 000000000..466bf8562 --- /dev/null +++ b/work_log/MTP/MTP-2026-04-10.md @@ -0,0 +1,801 @@ +# 2026-04-10 ATOM Plugin 模式下 DeepSeek MTP 的 CUDAGraph 调试、修复与知识沉淀 + +## 目标 + +今天这轮工作的目标主要有两条: + +1. 继续推进 `ATOM plugin + SGLang + DeepSeek MTP` 路径下的 `CUDAGraph` 调试,重点把昨天留下的 graph-only 问题继续收敛。 +2. 把今天在调试过程中澄清的一些关键背景知识整理成系统化文档,方便后续继续做 MTP / speculative / CUDAGraph 相关工作时复盘和学习。 + +今天聚焦的问题主要有两个: + +- 为什么此前在 `CUDAGraph + TARGET_VERIFY + MLA decode` 路径上,`kv_last_page_lens` 会落在 CPU 上。 +- 为什么后续修掉 CPU tensor 问题后,又在 `draft_extend replay` 路径上遇到 `bs=1` 但 `seq_lens.shape[0]=48` 的 shape mismatch。 + + +## 本次结论速览 + +今天最关键的结论有四个: + +1. `kv_last_page_lens(device=cpu)` 的根因不是简单的“plugin backend override 没生效”,而是: + - `EAGLE draft` 路径里使用的 `AiterMultiStepDraftBackend` + - 在其内部直接实例化了 upstream `AiterAttnBackend` + - 绕过了 plugin 通过 `"aiter"` 名字注册的 `ATOMAttnBackendForSgl` + - 结果 upstream 默认的 CPU `cuda_graph_kv_last_page_len` 泄漏进 graph 路径 + +2. `draft_extend replay` 的 `bs=1` 但 `seq_lens.shape[0]=48` 报错,本质上是: + - replay 选中的 graph bucket 是 `1` + - 但 `draft_extend cuda graph runner` 把整块静态 buffer 传给了 backend + - plugin backend 的 `draft_extend replay` 分支最初没有像 upstream 一样先做 `seq_lens[:bs]` 和 `accept_length[:bs]` 的规整 + +3. 修这个 `draft_extend replay` 问题,更合理的做法不是在函数入口打一层粗粒度的统一切片补丁,而是: + - 让 plugin 的 `init_forward_metadata_replay_cuda_graph()` 中 + - `forward_mode.is_draft_extend()` 这个分支 + - 在语义上与 upstream `AiterAttnBackend` 的对应分支对齐 + +4. 今天补充出的两篇专题文档,把下面几类知识基本梳理清楚了: + - `CUDAGraph` 在 `SGLang` 中到底固定什么 + - 为什么 `decode` 更适合做 graph + - 为什么普通 `prefill / extend` 更难 graph 化 + - `raw_bs / bs / num_tokens` + - `qo_indptr / kv_indptr / kv_indices` + - 特别是 `kv_indices` 的物理含义、shape、和 `req_to_token` 的对应关系 + + +## 今天开始时的上下文 + +昨天的工作已经完成了几件重要事情: + +- plugin draft wrapper 已接到 `ATOM DeepSeekMTP` +- `SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.models` 已经生效 +- `TARGET_VERIFY / DRAFT_EXTEND` 的 graph metadata 分支已经在 plugin backend 中补齐 +- `k_scale / v_scale` 的 CPU tensor 问题已经修过 + +但昨天留下的核心阻塞点是: + +- `CUDAGraph + TARGET_VERIFY + MLA decode` 路径里 +- `aiter.mla_decode_stage1_asm_fwd(...)` 之前仍然有 metadata tensor 落在 CPU +- 明确抓到的现场是: + - `kv_last_page_lens(device=cpu, dtype=torch.int32, shape=(48,), is_cuda=False)` + +昨天的一个判断是: + +- `ATOMAttnBackendForSgl.init_cuda_graph_state()` override 本身大概率是成功的 +- 更可能的问题在 graph metadata 组装或者某条 graph-only 调用链里 + +今天的工作就是在这个基础上继续收敛。 + + +## 背景知识补充 + +### 1. 为什么今天的两个 bug 都是 graph-only bug + +今天碰到的两个 bug: + +- `kv_last_page_lens` 落在 CPU +- `draft_extend replay` 的 `bs=1` / `seq_lens.shape[0]=48` + +都有一个共同特点: + +- 普通 eager runtime 或普通 speculative 路径不一定暴露 +- 但在 `cuda graph capture / replay` 路径上会被迅速放大 + +原因在于 graph 路径有两层额外约束: + +1. 需要静态持久 buffer +2. 需要严格区分: + - 当前 bucket 的有效视图 + - 底层静态 backing buffer + +graph-only bug 往往不是“attention 算法错了”,而是: + +- backend 选型不对 +- graph state 的 persistent buffer 设备不对 +- replay 时 view 与 backing buffer 混淆 + +### 2. `raw_bs`、`bs`、`num_tokens` + +今天反复用到的三个概念: + +- `raw_bs` + - 真实 batch 中当前有多少个 request +- `bs` + - 这次 replay 选中的 graph bucket 大小 +- `num_tokens` + - 本轮真正参与 forward 的 token 数 + +这三者在 graph 路径中不一定相等。 + +特别是在 speculative 路径中: + +- `draft decode` + - `num_tokens = bs * topk` +- `draft extend` + - `num_tokens = bs * (speculative_num_steps + 1)` +- 普通 prefill + - 常常是 `sum(extend_seq_lens)` + +### 3. `kv_indices` 的意义 + +今天在整理文档时,又把 `kv_indices` 这类字段重新梳理了一遍。 + +一句话: + +- `kv_indices` 不是逻辑 token 下标 +- 它是当前这轮 attention 真正要访问的 **physical KV slot id 列表** + +它和下面几个量一起理解最清楚: + +- `req_pool_indices` +- `req_to_token` +- `seq_lens` +- `kv_indptr` + +也就是: + +- 先按 `req_pool_indices` 找到 request 在 `req_to_token` 中的那一行 +- 再按 `seq_lens[i]` 取出这行前面的有效 token -> physical slot 映射 +- 最后拼成一个一维 flatten 数组 + +这个理解对于看懂 `create_flashinfer_kv_indices_triton(...)` 非常重要。 + + +## 问题一:`kv_last_page_lens` 在 CPU 上 + +### 现象 + +从 `log.serve.log` 和之前在 `aiter/mla.py` 加的 debug 可见: + +- 进入 `mla_decode_stage1_asm_fwd` 前 +- 只有 `kv_last_page_lens` 仍然是 CPU tensor +- 其他关键张量如: + - `q` + - `kv_buffer` + - `qo_indptr` + - `kv_indptr` + - `kv_indices` + - `work_metadata` + - `q_scale` + - `kv_scale` + 都已经是 CUDA tensor + +### 初始猜想 + +最初怀疑的方向有两个: + +1. plugin backend 的 `init_cuda_graph_state()` override 根本没生效 +2. override 生效了,但 graph metadata 组装阶段又把 `forward_metadata.kv_last_page_len` 绑回了别处的 CPU tensor + +### 继续追链路后的发现 + +今天顺着 `EAGLE draft` 的 graph 初始化链继续往下看,发现真正关键的地方是: + +- `EAGLEWorker.init_attention_backend()` + - 会通过 `DraftBackendFactory.create_decode_backend()` 构造 draft 的 decode backend +- 当 backend 选择是 `"aiter"` 时 + - `DraftBackendFactory._create_aiter_decode_backend()` + - 会直接实例化 `AiterMultiStepDraftBackend` +- 而 `AiterMultiStepDraftBackend.__init__()` 内部又会直接 new: + - `AiterAttnBackend(...)` + +问题就在这里: + +- 这条 direct construction 没走 attention registry +- 所以 plugin 在 `"aiter"` 名字上注册的 `ATOMAttnBackendForSgl` + 并不会自动生效 + +换句话说: + +- target 路径可能走的是 plugin backend +- 但 draft multi-step graph 路径内部某些 step backend 仍然是 upstream `AiterAttnBackend` + +### 为什么这会导致 CPU `kv_last_page_lens` + +上游 `AiterAttnBackend.init_cuda_graph_state()` 里有: + +- `self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)` + +也就是: + +- 默认建在 CPU 上 + +而 plugin 版已经修成: + +- `torch.ones(max_bs, dtype=torch.int, device=self.device)` + +因此,一旦 `AiterMultiStepDraftBackend` 内部 step backend 实际还是 upstream 版本: + +- graph state 里的 `cuda_graph_kv_last_page_len` + 就是 CPU tensor + +这就是之前 `mla_decode_fwd` 抓到: + +- `kv_last_page_lens(device=cpu, ...)` + +的根因。 + +### 修法 + +修法放在 plugin 层,不改 upstream `sglang`: + +- 文件: + - `ATOM/atom/plugin/register.py` + +做法: + +- 继续注册 `"aiter" -> ATOMAttnBackendForSgl` +- 同时 monkeypatch: + - `sglang.srt.layers.attention.aiter_backend.AiterAttnBackend` + - 让 direct import / direct construction 也落到 plugin backend + +核心代码: + +- import upstream `aiter_backend` 模块 +- `sglang_aiter_backend.AiterAttnBackend = ATOMAttnBackendForSgl` + +### 为什么这个修法合理 + +因为这个问题的本质不是: + +- metadata 算法错 + +而是: + +- graph 路径内部实际跑的 backend 实例不对 + +所以应该在 plugin 注入层统一修 backend 选型,而不是在更下游继续 patch 每个 graph state 字段。 + +### 验证 + +补了对应单测: + +- 文件: + - `ATOM/tests/plugin/test_sglang_register.py` + +新增检查: + +- 除了验证 `"aiter"` 名字的 registry 绑定 +- 还显式验证: + - `sglang.srt.layers.attention.aiter_backend.AiterAttnBackend` + - 也被换成了 plugin backend + +执行: + +- `pytest -q ATOM/tests/plugin/test_sglang_register.py` + +结果: + +- `9 passed` + + +## 问题二:`draft_extend replay` 中 `bs=1`,但 `seq_lens.shape[0]=48` + +### 现象 + +在下一轮服务运行中,又遇到一个新的 graph-only 错误: + +- `draft_extend replay` +- `init_forward_metadata_replay_cuda_graph()` +- `kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)` +- 报: + - `Target sizes: [1]` + - `Tensor sizes: [48]` + +也就是: + +- 当前 replay bucket 是 `bs = 1` +- 但传进来的 `seq_lens` 仍然是长度 `48` 的静态 backing buffer + +### 为什么会这样 + +`draft_extend cuda graph runner` 在 replay 时确实是这样传参的: + +- 会先把真实 batch 数据 copy 到静态 buffer 的前缀 +- 但调用 backend 时,传的是整块 `buffers.seq_lens` / `buffers.req_pool_indices` +- 而不是 `[:bs]` 视图 + +这意味着: + +- caller 传的是 backing buffer +- callee 却按“当前 bucket 的有效 view”去理解 + +于是就会出现: + +- 左边 slice 长度是 `1` +- 右边 `cumsum(seq_lens)` 长度是 `48` + +### 一开始的临时修法 + +我最初为了快速兜住问题,尝试过: + +- 在 plugin backend 的 `init_forward_metadata_replay_cuda_graph()` 入口 +- 对 `req_pool_indices / seq_lens / seq_lens_cpu` + 统一做 `[:bs]` + +这个做法能修掉 shape mismatch,但后来用户指出了一个更重要的问题: + +- 这里不应该只做粗粒度防御性 patch +- 更应该看 upstream 同分支是怎么处理的 + +这是对的。 + +### 继续对 upstream 后的发现 + +upstream `AiterAttnBackend.init_forward_metadata_replay_cuda_graph()` 的 +`draft_extend` 分支里,语义并不只是: + +- `seq_lens = seq_lens[:bs]` + +还包括: + +- `accept_lens = spec_info.accept_length[:bs]` +- `qo_indptr[1:] = cumsum(accept_lens)` + +也就是说,upstream 在 replay 阶段的 query 分段不是简单固定步长, +而是和 `accept_length` 绑定。 + +### 最终修法 + +因此最终采用的不是“函数入口统一切片”的粗 patch,而是: + +- 让 plugin 的 + - `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` + 中 + - `init_forward_metadata_replay_cuda_graph()` + 的 `forward_mode.is_draft_extend()` 分支 + 在语义上对齐 upstream + +具体做了两件事: + +1. `seq_lens = seq_lens[:bs]` +2. `accept_lens = spec_info.accept_length[:bs]` +3. `qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)` + +并保留 plugin 自己的: + +- `ForwardMetadata` 字段布局 +- MLA persistent kernel metadata 生成 +- `kv_last_page_len.is_cuda` 断言 + +### 一个小插曲:`qo_indptr[0] = 0` + +在第一次改这个分支时,我曾额外加过: + +- `qo_indptr[0] = 0` + +后来和 upstream 对比后又删掉了,原因是: + +- upstream 没有这行 +- `self.qo_indptr` 在父类中本来就是 `torch.zeros(...)` 初始化 +- 这类 CSR/indptr buffer 在很多地方都只写 `1:` + +因此: + +- 这行虽然防御性上没错 +- 但既然目标是与 upstream 严格语义对齐,就不应额外保留 + + +## 关于“为什么 graph 路径里 metadata 可以每轮重建” + +今天还专门澄清了一个很重要的问题: + +- graph replay 里 metadata 明明每轮都会重新构造 +- 为什么 decode / draft_extend 还能做 graph +- 而 prefill 更难 + +### 结论 + +不是“metadata 变了就不能 graph”,而是要区分: + +1. metadata 的**内容**变 +2. metadata 是否会进一步影响: + - graph 内部分支 + - 中间 tensor shape + - workspace shape + - kernel launch 形态 + +对于 decode / draft_extend graph 路径: + +- metadata 的构造发生在 graph 外 +- graph 内看到的是固定地址、固定 shape 的 persistent buffer +- replay 只是改这些 buffer 的内容 + +因此 graph 仍然可复用。 + +而普通 prefill: + +- `total_tokens` +- `max_q_len` +- `max_kv_len` +- `qo_indptr` +- `kv_indptr` +- workspace 大小 +都可能跟 batch 一起变化 + +这时 metadata 已经不只是“参数”,而更像“问题几何结构”的一部分。 + +所以: + +- 普通 prefill 仍然难 graph +- 即使你不考虑不同 kernel path + + +## 今天补出的三篇背景文档 + +今天除了修 bug,还补了三份面向复盘的文档。 + +### 1. `2026-04-10-sglang-cudagraph-prefill-decode-metadata-guide.md` + +主题: + +- `CUDAGraph` 在 `SGLang` 中固定的到底是什么 +- 为什么 `decode` 更适合做 graph +- 为什么普通 `prefill / extend` 难 graph +- `raw_bs / bs / num_tokens` +- `SGLang` 在 decode 阶段怎样做 capture / replay +- `ForwardMetadata` 在 graph capture / replay 中扮演什么角色 + +适合什么时候看: + +- 想从整体架构层面理解 `graph + attention metadata` + +### 2. `2026-04-10-sglang-simple-prefill-cudagraph-metadata-guide.md` + +主题: + +- 只收窄到“最简单 prefill” +- 不考虑不同 kernel path +- 不考虑 prefix cache +- 不考虑 speculative +- 解释: + - 为什么即使是最简单 prefill 也难 graph + - `qo_indptr / kv_indptr / kv_indices / max_q_len / max_kv_len` + 的 shape 与物理意义 + - 给多个可手算例子 + +特别补充了: + +- `kv_indices` 的详细解释 +- `req_to_token` +- `req_pool_indices` +- 以及 `qo_indptr + kv_indptr + kv_indices` 联合看时的完整例子 + +适合什么时候看: + +- 想快速回忆 metadata 的 shape 和物理含义 + +### 3. `MTP-2026-04-10.md` + +也就是本文件,作为今天的完整工作日报。 + + +## 今天的代码改动清单 + +### 1. `ATOM/atom/plugin/register.py` + +改动: + +- 除了继续通过 registry 把 `"aiter"` 绑定到 `ATOMAttnBackendForSgl` +- 还显式把 upstream 模块符号: + - `sglang.srt.layers.attention.aiter_backend.AiterAttnBackend` + 重新绑定到 plugin backend + +目的: + +- 修复 `AiterMultiStepDraftBackend` 内部 direct construction 绕过 registry 的问题 + +### 2. `ATOM/tests/plugin/test_sglang_register.py` + +改动: + +- 扩展测试覆盖 +- 不仅验证 registry name 是 `"aiter"` +- 也验证 `AiterAttnBackend` 模块符号被替换成了 plugin backend + +目的: + +- 防止以后又出现“registry 绑了,但 direct construction 仍然绕过 plugin”的回归 + +### 3. `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` + +改动: + +- 继续保留 graph metadata 断言 +- 把 `init_forward_metadata_replay_cuda_graph()` 中 + `forward_mode.is_draft_extend()` 分支 + 调整为更接近 upstream 的语义: + - `seq_lens = seq_lens[:bs]` + - `accept_lens = spec_info.accept_length[:bs]` + - `qo_indptr[1:] = cumsum(accept_lens)` + - `kv_indptr[1:] = cumsum(seq_lens)` + +目的: + +- 修复 `draft_extend replay` 中 + - `bs=1` + - 但 `seq_lens.shape[0]=48` + 的 graph-only mismatch + +### 4. 新增文档 + +- `ATOM/work_log/MTP/2026-04-10-sglang-cudagraph-prefill-decode-metadata-guide.md` +- `ATOM/work_log/MTP/2026-04-10-sglang-simple-prefill-cudagraph-metadata-guide.md` +- `ATOM/work_log/MTP/MTP-2026-04-10.md` + + +## 今天做过但没有保留的尝试 + +今天有一个短暂尝试后来撤回了: + +- 我曾经直接修改过 upstream: + - `sglang/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py` +- 试图在 caller 侧把传给 backend 的 `buffers.seq_lens` / `buffers.req_pool_indices` + 改成 `[:bs]` + +这个改法从技术上能工作,但不符合这次工作的原则: + +- 用户要求不要改 upstream `sglang` +- 这类修法也会把 plugin 自己对 graph 语义的偏差掩盖掉 + +因此后来把它撤回了,最终修法完全落在 plugin 内部。 + + +## 实验过程与结果 + +### 实验 1:确认 `kv_last_page_lens` 问题是不是 graph state 初始化失败 + +动机: + +- 区分问题到底是: + - plugin backend 的 graph state 初始化失败 + - 还是后续调度/metadata 链路里又掉回 upstream backend + +方法: + +- 检查 plugin backend 中 `init_cuda_graph_state()` 的实现 +- 对照 upstream `AiterAttnBackend` +- 追 `EAGLEWorker -> DraftBackendFactory -> AiterMultiStepDraftBackend` + 的实例化链路 + +结果: + +- 发现不是简单的 override 失败 +- 而是 `AiterMultiStepDraftBackend` 内部 direct new `AiterAttnBackend` + 绕过了 plugin registry + +结论: + +- 这是 backend 注入点不完整导致的 graph-only bug + +### 实验 2:确认修 backend 注入后,下一处报错落在哪里 + +动机: + +- 修完 CPU tensor 问题后,需要继续看下一层 graph-only 问题 + +方法: + +- 根据新的 `log.serve.log` 栈追 `draft_extend replay` + +结果: + +- 抓到新的错误: + - `kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)` + - `Target sizes: [1]` + - `Tensor sizes: [48]` + +结论: + +- graph runner 的 backing buffer 和当前 bucket view 混淆了 + +### 实验 3:先做临时入口兜底,再回看 upstream 语义 + +动机: + +- 先快速验证 shape mismatch 是否确实是 view 问题 + +方法: + +- 在 plugin `init_forward_metadata_replay_cuda_graph()` 入口加统一 `[:bs]` + 规整 + +结果: + +- 从逻辑上能兜住这类 mismatch + +但随后进一步对照 upstream 发现: + +- 这个问题不只是切片 +- 还牵涉到 `draft_extend replay` 下 `qo_indptr` 的语义应该由 + `accept_length` 决定 + +结论: + +- 函数入口打补丁只是临时兜底 +- 更好的修法是按 upstream 分支语义对齐 + +### 实验 4:对齐 plugin `draft_extend replay` 分支到 upstream + +动机: + +- 让 plugin 分支和 upstream 的 graph replay 语义一致 + +方法: + +- 把 plugin 的 `draft_extend replay` 分支改为: + - `seq_lens = seq_lens[:bs]` + - `accept_lens = spec_info.accept_length[:bs]` + - `qo_indptr[1:] = cumsum(accept_lens)` + +结果: + +- 代码语义与 upstream 一致性明显更强 +- 且不需要修改 upstream `sglang` + +### 实验 5:运行服务并看最新日志 + +动机: + +- 看修复后服务是否恢复到正常请求处理状态 + +结果: + +- `log.serve.log` 最新尾部可见: + - Prefill batch 仍显示: + - `cuda graph: False` + - Decode batch 显示: + - `cuda graph: True` + - 多个请求成功返回 `200 OK` + +这至少说明: + +- 当前服务已经重新进入了正常请求处理状态 +- prefill 依旧不走 graph,这与当前系统设计一致 +- decode graph 已经在工作 + +需要注意: + +- 今天没有做系统化 benchmark +- 这里只能说明日志上服务在继续跑,不能说明所有 corner case 都完全验证完成 + + +## 今天形成的理解:为什么 decode graph 可以,而 prefill 更难 + +今天围绕用户问题,又把这件事重新总结了一遍。 + +### 1. decode graph 为什么更自然 + +因为 decode 常常满足: + +- 每个 request 每轮 query token 数固定 +- `num_tokens_per_bs` 固定 +- `max_q_len` 常常固定为 `1` +- metadata 更多是在固定地址上的输入数据 + +因此: + +- 可以通过 `bs bucket + 静态 buffer + metadata 重建` + 来复用 graph + +### 2. prefill graph 为什么更难 + +即使不考虑 prefix cache 和不同 kernel path,最简单 prefill 也会遇到: + +- `total_tokens = sum(extend_seq_lens)` 会变 +- `qo_indptr` 会变 +- `max_q_len` 会变 +- `max_kv_len` 会变 +- 中间 tensor / workspace shape 也会变 + +所以它的 challenge 不只是: + +- metadata 值变了 + +而是: + +- metadata 连同问题几何结构一起变了 + +这会让: + +- graph 内部张量 shape +- workspace 形状 +- 有时连 kernel launch 计划 +都跟着变化 + + +## 当前状态判断 + +截至今天结束,可以比较有把握地说: + +1. DeepSeek MTP 的 plugin draft 路径接线已经比昨天更稳固: + - direct construction 绕过 plugin backend 的坑已经堵住 + +2. `draft_extend replay` 的 graph metadata 语义也更接近 upstream: + - 不再是一个仅靠入口切片兜底的 patch + - 而是按 upstream 分支语义修正 + +3. 从最新日志看: + - prefill 继续按设计走 `cuda graph: False` + - decode 已经在 `cuda graph: True` + - 服务能正常响应请求 + +4. 但今天没有完成的事情仍然有: + - 没有做系统化压测或 benchmark + - 没有确认所有 speculative 相关 corner case 都已经覆盖 + - 没有进一步处理“普通 prefill 是否值得 graph 化”的工程设计 + + +## 本次执行过的验证命令 + +今天执行过的本地验证主要包括: + +- `pytest -q ATOM/tests/plugin/test_sglang_register.py` + - 结果:通过 +- `python3 -m py_compile` + - 对修改过的 plugin 文件做语法检查 + - 结果:通过 + +另外: + +- 通过 `log.serve.log` 持续跟踪服务运行结果 +- 从日志尾部确认: + - `Prefill batch ... cuda graph: False` + - `Decode batch ... cuda graph: True` + - 多个 `/v1/completions` 返回 `200 OK` + + +## 对后续工作的建议 + +### 1. 继续观察真实服务日志 + +虽然今天的两个 graph-only bug 已经定位并修正,但建议继续跑一段时间,看是否还有新的: + +- speculative-only +- graph-only +- replay-only + +问题继续冒出。 + +### 2. 如果再出现 graph-only 问题,优先检查两类契约 + +今后的 graph 调试,优先看: + +1. backend 注入契约 + - 实际实例是不是 plugin backend +2. backing buffer / bucket view 契约 + - caller 传的是整块静态 buffer,还是当前 `[:bs]` view + +很多 graph-only bug 最后都归到这两类。 + +### 3. 如需进一步探索 prefill graph,可先限定一个最简单子场景 + +如果未来要继续研究: + +- “普通 prefill 是否也能做 graph” + +建议不要直接想做“所有 prefill graph”,而是先限定: + +- 无 prefix cache +- 固定 kernel path +- 小量 bucket + +然后再评估: + +- `total_tokens` +- `max_q_len` +- `max_kv_len` +- workspace + +能否通过 bucket 化或 pad/unpad 收敛。 + + +## 一句话总结 + +今天最核心的进展不是“DeepSeek MTP 的所有问题都已解决”,而是: + +- 把 `kv_last_page_lens` 掉到 CPU 的 graph-only 根因准确收敛到 + `AiterMultiStepDraftBackend` 绕过 plugin backend +- 把 `draft_extend replay` 的 `bs=1 / seq_lens=48` 问题从临时补丁, + 收敛为与 upstream 语义对齐的修法 +- 同时把 `CUDAGraph / decode / prefill / metadata / kv_indices` + 这一整套背景知识整理成了更适合后续复盘的文档体系 From cb3f4fc636c7d983956878978fc0769892e2043f Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Thu, 23 Apr 2026 08:47:48 +0000 Subject: [PATCH 3/4] adopt new attn constructor args --- atom/plugin/sglang/attention_backend/sgl_attn_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/atom/plugin/sglang/attention_backend/sgl_attn_backend.py b/atom/plugin/sglang/attention_backend/sgl_attn_backend.py index 2a26f8083..4b4f82972 100644 --- a/atom/plugin/sglang/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/sglang/attention_backend/sgl_attn_backend.py @@ -216,8 +216,9 @@ def __init__( model_runner: ModelRunner, skip_prefill: bool = False, kv_indptr_buf: Optional[torch.Tensor] = None, + topk: int = 1, ): - super().__init__(model_runner, skip_prefill, kv_indptr_buf) + super().__init__(model_runner, skip_prefill, kv_indptr_buf, topk) mapping = getattr( model_runner.token_to_kv_pool, "full_attention_layer_id_mapping", None ) From e07ea02cd3f6f573f04dbcca91bb0a9b71774c85 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Fri, 24 Apr 2026 08:02:23 +0000 Subject: [PATCH 4/4] rm worklog --- ...deepseek-speculative-attention-metadata.md | 721 ---------- ...8-sglang-attention-backend-fields-guide.md | 908 ------------ ...026-04-08-sglang-kv-cache-storage-guide.md | 692 ---------- ...glang-speculative-decoding-architecture.md | 910 ------------ .../2026-04-08-vllm-continuous-batching.md | 1214 ----------------- ...cudagraph-prefill-decode-metadata-guide.md | 734 ---------- ...simple-prefill-cudagraph-metadata-guide.md | 866 ------------ work_log/MTP/MTP-2026-04-08.md | 525 ------- work_log/MTP/MTP-2026-04-09.md | 715 ---------- work_log/MTP/MTP-2026-04-10.md | 801 ----------- 10 files changed, 8086 deletions(-) delete mode 100644 work_log/MTP/2026-04-08-deepseek-speculative-attention-metadata.md delete mode 100644 work_log/MTP/2026-04-08-sglang-attention-backend-fields-guide.md delete mode 100644 work_log/MTP/2026-04-08-sglang-kv-cache-storage-guide.md delete mode 100644 work_log/MTP/2026-04-08-sglang-speculative-decoding-architecture.md delete mode 100644 work_log/MTP/2026-04-08-vllm-continuous-batching.md delete mode 100644 work_log/MTP/2026-04-10-sglang-cudagraph-prefill-decode-metadata-guide.md delete mode 100644 work_log/MTP/2026-04-10-sglang-simple-prefill-cudagraph-metadata-guide.md delete mode 100644 work_log/MTP/MTP-2026-04-08.md delete mode 100644 work_log/MTP/MTP-2026-04-09.md delete mode 100644 work_log/MTP/MTP-2026-04-10.md diff --git a/work_log/MTP/2026-04-08-deepseek-speculative-attention-metadata.md b/work_log/MTP/2026-04-08-deepseek-speculative-attention-metadata.md deleted file mode 100644 index bc2e136a6..000000000 --- a/work_log/MTP/2026-04-08-deepseek-speculative-attention-metadata.md +++ /dev/null @@ -1,721 +0,0 @@ -# 2026-04-08 DeepSeek Speculative 与 Attention Metadata 关系笔记 - -## 文档目的 - -本文专门解释一个在调试 DeepSeek speculative / MTP 时非常关键、但又很容易被低估的问题: - -- **speculative decoding 和 attention metadata 到底是什么关系?** - -对于 DeepSeek 这类 MLA 模型来说,这个问题尤其重要。因为 speculative decoding -不是简单地“多跑一个 draft model”,它会直接改变: - -- 当前 batch 有多少 query token -- 每个 query 应该看到哪些 KV -- 这些 KV 在 paged KV cache 中的索引方式 -- 是否需要树状 mask / causal mask -- MLA kernel 需要的 workspace / split / persistent metadata - -换句话说: - -**speculative decoding 在 runtime 层的本质,就是不断重写 attention metadata。** - - -## 一句话理解 - -可以把 attention metadata 理解为: - -- “这一次 attention 要怎么看 KV cache”的说明书 - -而 speculative decoding 做的事情,本质上就是不断改变这份说明书: - -- normal decode:每个请求 1 个 query,查自己已有上下文 -- draft extend:一次要处理多个 draft token,需要新的 `qo_indptr` 与 mask -- target verify:要同时验证多个候选 token,query 形状和 KV 长度都变了 -- DeepSeek MLA / MTP:`max_q_len`、`kv_indptr`、`qo_indptr`、`work_metadata` - 会直接影响 kernel 如何执行 - - -## 1. 为什么这个问题在 DeepSeek 上特别重要 - -DeepSeek 使用 MLA(Multi-head Latent Attention)后,attention metadata 的作用比普通 -MHA 更重: - -- 普通 MHA 更多是 query/key/value 张量形状和 mask 变化 -- MLA 还要额外构造: - - `kv_indptr` - - `kv_indices` - - `qo_indptr` - - `kv_last_page_len` - - `max_q_len` - - `work_metadata` - - `work_info_set` - - `reduce_indptr` - - `reduce_final_map` - - `reduce_partial_map` - -这些量直接决定: - -- MLA persistent kernel 如何分块 -- 每个 query 要从 paged KV cache 里取哪些 token -- multi-query(例如 verify / MTP)时 query 维度如何展开 - -所以在 DeepSeek speculative 路径中,真正最敏感的往往不是 model forward 本身, -而是 **attention metadata 是否按正确语义构出来**。 - - -## 2. 先看三层 batch 抽象 - -核心文件: - -- `sglang/python/sglang/srt/model_executor/forward_batch_info.py` -- `sglang/python/sglang/srt/managers/schedule_batch.py` - -SGLang 中 batch 数据结构有三层: - -- `ScheduleBatch` -- `ModelWorkerBatch` -- `ForwardBatch` - -源码注释位置: - -- `forward_batch_info.py` 文件开头 - -这三层的职责可以粗略理解为: - -- `ScheduleBatch` - - scheduler 视角 - - 关注请求、prefix、token、调度状态 -- `ModelWorkerBatch` - - worker 视角 - - 关注一次 GPU forward 所需字段 -- `ForwardBatch` - - backend / kernel 视角 - - 关注 query、KV、cache、metadata - -**attention metadata 的最终落点是在 `ForwardBatch -> attn_backend.init_forward_metadata()`** -这一层。 - - -## 3. `ForwardMode`:speculative 如何改变 metadata 初始化分支 - -核心文件: - -- `sglang/python/sglang/srt/model_executor/forward_batch_info.py` - -关键枚举: - -- `ForwardMode` - -关键 speculative mode: - -- `TARGET_VERIFY` -- `DRAFT_EXTEND` -- `DRAFT_EXTEND_V2` - -关键代码位置: - -- `ForwardMode` 定义:约 `74-179` - -最容易踩坑的一点: - -- `ForwardMode.is_extend()` 会把 `TARGET_VERIFY` 也算进去 - -对应逻辑: - -- `forward_batch_info.py` 约 `105-114` - -这意味着如果某个 backend 只是粗暴地区分: - -- decode -- extend - -而没有再细分: - -- target_verify -- draft_extend - -那么它很容易把 verify 当普通 extend 处理,然后在 metadata 上出错。 - - -## 4. speculative 信息是如何进入 attention 层的 - -核心抽象: - -- `SpecInput` - -文件: - -- `sglang/python/sglang/srt/speculative/spec_info.py` - -关键点: - -- `SpecInput` 不是附带信息,而是 speculative 与 attention metadata 的桥梁 -- 它负责携带: - - speculative token 相关信息 - - 需要的 positions - - `kv_indptr` / `kv_indices` - - `custom_mask` - - `accept_length` - - `draft_token_num` - - 其他草稿 / 验证所需状态 - -相关位置: - -- `SpecInputType`:约 `108-113` -- `SpecInput`:约 `116-143` - -这里有一个很重要的方法: - -- `get_spec_adjusted_global_num_tokens()` - -它说明 speculative decoding 会直接改变: - -- global num tokens -- logprob token 数 - -这也间接影响 batch padding 和后续 metadata 构造。 - - -## 5. `EagleVerifyInput`:speculative 到 metadata 的第一层接口 - -核心文件: - -- `sglang/python/sglang/srt/speculative/eagle_info.py` - -关键类: - -- `EagleVerifyInput` - -关键字段: - -- `draft_token` -- `custom_mask` -- `positions` -- `draft_token_num` -- `capture_hidden_mode` -- `seq_lens_sum` -- `seq_lens_cpu` - -代码位置: - -- `eagle_info.py` 约 `54-78` - -这些字段本身就已经说明了 speculative 和 metadata 的关系: - -- `draft_token` - - 决定 verify 阶段实际送入 target 的 query token -- `positions` - - 决定 RoPE / position indexing -- `draft_token_num` - - 决定一次 verify 需要几个 query -- `custom_mask` - - 决定树状 speculative 验证时的可见性 - - -## 6. verify 阶段:speculative 如何改写 batch - -### 6.1 v1 路径 - -关键文件: - -- `sglang/python/sglang/srt/speculative/eagle_worker.py` -- `sglang/python/sglang/srt/speculative/eagle_info.py` - -关键流程: - -1. `draft()` 先生成候选 token,形成 `EagleVerifyInput` -2. `verify()` 调 `spec_info.prepare_for_verify(batch, page_size)` -3. `batch.forward_mode` 被改成 `TARGET_VERIFY` -4. target worker 执行 verify forward - -关键位置: - -- `eagle_worker.py` 中 `verify()`:约 `699-788` -- `eagle_info.py` 中 `prepare_for_verify()`:约 `104-146` - - -### 6.2 `prepare_for_verify()` 改了什么 - -它主要会做: - -- `batch.input_ids = self.draft_token` -- 分配 `batch.out_cache_loc` -- 更新 `req_to_token_pool` - -也就是: - -- target verify 不再看原先的“普通 decode 单 token 输入” -- 而是把所有 draft token 当作本轮 query 批次 - -这已经说明: - -- verify 不是普通 decode -- verify 的 query 形状和 KV 形状都变了 -- 所以 attention metadata 必须重新构造 - - -## 7. `generate_attn_arg_prefill()`:draft_extend 的 metadata 生成器 - -文件: - -- `sglang/python/sglang/srt/speculative/eagle_info.py` - -关键方法: - -- `generate_attn_arg_prefill()` - -代码位置: - -- 约 `160-216` - -这个函数非常关键,因为它直接把 speculative 信息翻译成 attention metadata 里的核心索引: - -- `qo_indptr` -- `cum_kv_seq_len`(本质上就是 `kv_indptr`) -- `kv_indices` -- `custom_mask` - -可以理解为: - -- speculative 输入先描述“我要验证/扩展多少个 draft token、树结构是什么” -- `generate_attn_arg_prefill()` 再把这种高层语义翻译成 kernel 能消费的索引格式 - - -### 7.1 `qo_indptr` 是什么 - -在这里: - -- `qo_indptr` 表示 query output token 在 batch 中如何分段 - -例如: - -- 每个请求有 `draft_token_num` 个 query -- 那么 `qo_indptr` 就会按这个 query 数量分桶 - - -### 7.2 `kv_indptr` / `cum_kv_seq_len` 是什么 - -它表示: - -- 每个请求在当前 forward 中可见的 KV token 范围 - -draft / verify 会把: - -- 原始 `paged_kernel_lens` - -扩成: - -- `paged_kernel_lens + draft_token_num` - -这说明 speculative decoding 不是只“多几个 query”,而是连本轮可见 KV 长度都变了。 - - -### 7.3 `custom_mask` 是什么 - -对于树状 speculative decode: - -- 不是所有 draft token 都能互相看见 - -所以需要: - -- `custom_mask` - -来表示 tree-based causal structure。 - -这个量会在非 MLA MHA 路径里更直接地进入 attention kernel。 - - -## 8. v2 路径:`prepare_for_v2_verify()` 如何构造 verify metadata - -核心文件: - -- `sglang/python/sglang/srt/speculative/eagle_info_v2.py` - -关键方法: - -- `prepare_for_v2_verify()` - -代码位置: - -- `eagle_info_v2.py` 约 `213-270` - -这个方法做的事情可以理解成: - -1. 先按 speculative verify 语义设置: - - `batch.input_ids` - - `batch.out_cache_loc` -2. 把 `batch.forward_mode` 改成 `TARGET_VERIFY` -3. 通过 `ForwardBatch.init_new(batch, target_worker.model_runner)` - 得到真正的 `ForwardBatch` -4. 然后显式调用: - - `target_worker.model_runner.attn_backend.init_forward_metadata(verify_forward_batch)` - -这说明: - -- speculative verify 到 attention metadata 的连接点,不是在 model forward 里隐式发生的 -- 而是在 `prepare_for_v2_verify()` 中显式发生的 - - -## 9. attention metadata 长什么样 - -### 9.1 upstream SGLang `ForwardMetadata` - -文件: - -- `sglang/python/sglang/srt/layers/attention/aiter_backend.py` - -关键 dataclass: - -- `ForwardMetadata` - -代码位置: - -- `aiter_backend.py` 约 `76-95` - -关键字段: - -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `kv_last_page_len` -- `max_q_len` -- `max_kv_len` -- `work_metadata` -- `work_info_set` -- `reduce_indptr` -- `reduce_final_map` -- `reduce_partial_map` -- `num_kv_splits` -- `custom_mask` -- `mask_indptr` -- `max_extend_len` -- `fp8_prefill_kv_indices` - - -### 9.2 ATOM plugin 的 `ForwardMetadata` - -文件: - -- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` - -关键 dataclass: - -- `ForwardMetadata` - -代码位置: - -- `sgl_attn_backend.py` 约 `171-198` - -从字段上看,ATOM plugin 其实已经承认 speculative / MLA attention 需要这些索引和 workspace。 -所以当前问题不是“不知道这些量存在”,而是: - -- 没在 metadata init 分支上完全按 upstream 语义实现 - - -## 10. upstream `AiterAttnBackend.init_forward_metadata()` 如何按 speculative 分流 - -核心文件: - -- `sglang/python/sglang/srt/layers/attention/aiter_backend.py` - -关键方法: - -- `init_forward_metadata()` - -代码位置: - -- `aiter_backend.py` 约 `435-684` - -这是整条链最重要的代码之一。 - -它不是简单分成: - -- decode -- extend - -而是分成: - -1. `decode_or_idle` -2. `draft_extend` -3. `target_verify` -4. 普通 extend - - -### 10.1 普通 decode / idle - -逻辑: - -- `spec_info` 为空时,按普通 decode 构造 -- `spec_info` 不为空时,直接复用 `spec_info.kv_indptr / kv_indices` - -这说明 speculative 已经开始介入 decode metadata。 - - -### 10.2 `draft_extend` - -逻辑: - -- 调 `spec_info.generate_attn_arg_prefill()` -- 拿到: - - `kv_indices` - - `kv_indptr` - - `qo_indptr` - - `custom_mask` -- MLA 路再进一步根据 `extend_seq_lens_cpu` - 计算 `max_seqlen_qo` 和 persistent kernel metadata - -关键位置: - -- `aiter_backend.py` 约 `526-606` - - -### 10.3 `target_verify` - -这是最关键的一支。 - -逻辑: - -- 不依赖普通 extend 的 `extend_seq_lens` -- 直接用: - - `draft_num = spec_info.draft_token_num` - - `kv_lens = forward_batch.seq_lens + draft_num` -- 自己构造: - - `qo_indptr` - - `kv_indptr` - - `kv_indices` -- 对 MLA 路: - - `max_q_len = draft_num` - -关键位置: - -- `aiter_backend.py` 约 `607-684` - -这个分支完美说明: - -**verify 不是普通 extend,speculative 会直接重定义 query 长度和 KV 长度。** - - -## 11. DeepSeek MLA:为什么 speculative 更像“metadata 问题” - -对于 DeepSeek MLA 来说,attention forward 真正吃的是: - -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `kv_last_page_len` -- `max_q_len` -- `work_metadata` / `reduce_*` - -如果这些量不对: - -- 哪怕 model forward、q/k/v 张量本身都没问题 -- kernel 也会在错误的 KV 范围上工作 - -这就是为什么调试 speculative 时,attention metadata 的正确性往往比 model 本身更先决定成败。 - - -## 12. 本次调试得到的一个关键教训 - -在 `ATOM plugin` 当前实现中: - -- `sgl_attn_backend.py` 里的 `_forward_extend_mla()` 已经认识: - - `TARGET_VERIFY` - - `DRAFT_EXTEND` - -代码位置: - -- `sgl_attn_backend.py` 约 `1001-1022` - -但 metadata 初始化层还没有完全按 upstream 分支细化: - -- `init_forward_metadata()` 仍是: - - `decode_or_idle` - - else -> `extend` - -代码位置: - -- `sgl_attn_backend.py` 约 `282-288` - -于是 `TARGET_VERIFY` 会被误送进普通 `_init_extend_mla()`: - -- 它会错误假设 `forward_batch.extend_seq_lens` 一定存在 - -而在 verify 路径下: - -- `extend_seq_lens` 本来就可能是 `None` - -这就是为什么当前错误看起来像: - -- `NoneType has no attribute max` - -实际上本质是: - -- **speculative 和 attention metadata 的语义没有对齐** - - -## 13. 从 ATOM 原生 MTP 再看一次 metadata 的重要性 - -如果看 ATOM 原生链路: - -- `ATOM/atom/spec_decode/eagle.py` -- `ATOM/atom/model_ops/attentions/aiter_mla.py` - -会发现 speculative / MTP 对 attention metadata 的耦合更直接。 - -### 13.1 `EagleProposer.propose()` - -关键位置: - -- `atom/spec_decode/eagle.py` 约 `94-190` - -在多步 draft 过程中,会不断更新: - -- `attn_metadata.max_seqlen_q` -- `attn_metadata.max_seqlen_k` -- `kv_indptr` -- `kv_indices` -- `cu_seqlens_q` -- `slot_mapping` -- `kv_last_page_lens` - -并调用: - -- `prepare_mtp_decode()` - -这说明在 ATOM 原生实现里: - -- speculative 不是 attention 上的一点点附加参数 -- 而是会不断重写 attention metadata - - -### 13.2 `prepare_mtp_decode()` - -文件: - -- `ATOM/atom/model_ops/attentions/aiter_mla.py` - -关键位置: - -- `prepare_mtp_decode()`:约 `225-250` - -作用: - -- 为多 token 预测构造 MTP decode 需要的 KV / worker metadata - -同文件里还有一个重要信号: - -- `prepare_decode()` 会在有 drafter 时把 - `max_seqlen_q = drafter.mtp_k + 1` - -位置: - -- `aiter_mla.py` 约 `352-357` - -这再次说明: - -- speculative / MTP 本质上会改变 query 维度 -- query 维度一变,attention metadata 就必须重建 - - -## 14. 调试 speculative + metadata 时的实用检查表 - -如果后续继续调试 DeepSeek speculative / MTP,建议优先检查下面几项: - -### 1. 当前 `ForwardMode` 是什么 - -看: - -- `decode` -- `target_verify` -- `draft_extend` -- `draft_extend_v2` - -如果 mode 判断错了,metadata 分支通常也会错。 - - -### 2. 当前 `spec_info` 是不是空 - -如果 `spec_info` 不为空,就不应该再走普通 extend 的 metadata 逻辑。 - - -### 3. `qo_indptr` 是否和 speculative token 数一致 - -例如 verify 路径里: - -- `max_q_len` 应该接近 `draft_token_num` - -而不是普通 decode 的 `1`。 - - -### 4. `kv_indptr / kv_indices` 是否按 speculative 后的新 KV 长度构造 - -verify 阶段一般应当看到: - -- `kv_lens = seq_lens + draft_token_num` - -而不是原始 `seq_lens`。 - - -### 5. 是否错误依赖了 `extend_seq_lens` - -普通 extend 可以依赖: - -- `extend_seq_lens` - -但 `target_verify` 不应简单照搬这套假设。 - - -### 6. 是否需要 `custom_mask` - -树状 speculative / topk 路径下: - -- `custom_mask` - -常常是必须的;它缺失时可能不会立刻报错,但结果会错。 - - -## 15. 推荐阅读顺序 - -如果以后要重新从头搞清楚 “DeepSeek speculative 与 attention metadata 的关系”, -推荐按下面顺序阅读: - -1. `sglang/python/sglang/srt/model_executor/forward_batch_info.py` - - 看 `ForwardMode` -2. `sglang/python/sglang/srt/speculative/spec_info.py` - - 看 `SpecInput` -3. `sglang/python/sglang/srt/speculative/eagle_info.py` - - 看 `EagleVerifyInput` -4. `sglang/python/sglang/srt/speculative/eagle_info_v2.py` - - 看 `prepare_for_v2_verify()` -5. `sglang/python/sglang/srt/layers/attention/aiter_backend.py` - - 看 `init_forward_metadata()` 的四种 speculative 分支 -6. `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` - - 对照 plugin 当前实现和 upstream 差异 -7. `ATOM/atom/spec_decode/eagle.py` - - 看 ATOM 原生 speculative 是怎么驱动 attn metadata 更新的 -8. `ATOM/atom/model_ops/attentions/aiter_mla.py` - - 看 MTP decode 的 metadata 准备逻辑 - - -## 16. 最终总结 - -对于 DeepSeek 而言: - -- speculative decoding 的重点不只是 draft model -- attention metadata 才是把 speculative 语义真正落到 kernel 的关键层 - -可以用下面一句话概括: - -**draft / verify 负责决定“要处理哪些 token”,attention metadata 负责把这个决定变成 kernel 可执行的 KV / Q 索引和 workspace 说明。** - -因此,后续如果要在 `ATOM + SGLang plugin` 路径真正接通 DeepSeek MTP, -核心工作并不是只把 draft model 换成 `ATOM DeepSeekMTP`,还包括: - -- 让 plugin 的 attention metadata 初始化完整理解 - - `TARGET_VERIFY` - - `DRAFT_EXTEND` - - `MTP 多 query` - - `custom_mask` - - `qo_indptr / kv_indptr / kv_indices` - -只有这层语义也打通,DeepSeek speculative 才算真正可用。 diff --git a/work_log/MTP/2026-04-08-sglang-attention-backend-fields-guide.md b/work_log/MTP/2026-04-08-sglang-attention-backend-fields-guide.md deleted file mode 100644 index b7c170af7..000000000 --- a/work_log/MTP/2026-04-08-sglang-attention-backend-fields-guide.md +++ /dev/null @@ -1,908 +0,0 @@ -# SGLang Attention Backend 字段说明 - -## 文档目的 - -这篇文档专门解释 `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` -里 `ForwardMetadata` 的核心字段,重点放在: - -- 这些字段在调度链路里是怎么来的 -- 它们分别表示什么语义 -- 它们的 shape 是什么 -- 它们和 SGLang 的 KV cache 存储结构是什么关系 - -本文刻意**不展开** `reduce_indptr`、`reduce_final_map`、`reduce_partial_map` -这类更偏 kernel 内部 workspace 的字段,只在必要时顺带提一句。 - - -## 一句话理解 - -可以把 `ForwardMetadata` 理解成: - -- `scheduler / ForwardBatch` 已经决定了“这一步要算哪些 request、每个 request 算多少 query、这些 query 应该看到哪些 KV” -- `attn_backend.init_forward_metadata()` 负责把这个高层语义,转换成 attention kernel 真正能消费的低层索引 - -其中最关键的就是三类信息: - -- **Q 侧分段信息**:`qo_indptr`, `max_q_len` -- **KV 侧分段信息**:`kv_indptr`, `kv_indices`, `kv_last_page_len`, `max_kv_len` -- **非 MLA 下的 page 化信息**:`page_table`, `kv_lens` - - -## 1. 三层 batch 抽象 - -SGLang 里和 attention metadata 直接相关的 batch 抽象有三层: - -- `ScheduleBatch` -- `ModelWorkerBatch` -- `ForwardBatch` - -其中: - -- `ScheduleBatch` - - scheduler 视角 - - 关心请求、prefix、seq len、cache slot 分配 -- `ModelWorkerBatch` - - worker 视角 - - 是一次 GPU forward 所需字段的中间态 -- `ForwardBatch` - - attention backend / kernel 视角 - - 大部分字段已经是 GPU tensor - -可以粗略画成: - -```mermaid -flowchart LR - A[Scheduler / ScheduleBatch] - B[ModelWorkerBatch] - C[ForwardBatch] - D[init_forward_metadata] - E[ForwardMetadata] - F[Attention Kernel] - - A --> B --> C --> D --> E --> F -``` - -`ForwardMetadata` 就是 `ForwardBatch` 再往下走一步,把“批次语义”翻译成“索引语义”的结果。 - - -## 2. 调度到 metadata 的主链路 - -最值得记住的链路是: - -1. scheduler 决定这一步 batch 里有哪些 request -2. scheduler 为这些 request 分配或复用 KV slot -3. `ScheduleBatch.get_model_worker_batch()` 把调度状态打包 -4. `ForwardBatch.init_new()` 把 CPU 侧 list / 状态变成 GPU tensor -5. `attn_backend.init_forward_metadata()` 生成 `ForwardMetadata` - -对应几个关键字段来源如下: - -- `req_pool_indices` - - 来自 `ScheduleBatch` - - 表示每个 request 在 `ReqToTokenPool.req_to_token` 里的“行号” -- `seq_lens` - - 每个 request 当前参与 attention 的 KV 长度 -- `out_cache_loc` - - 本轮新 token 写入 KV cache 的物理 slot -- `extend_seq_lens` - - prefill / extend 时,每个 request 本轮真正新增了多少 query token -- `spec_info` - - speculative 路径下额外提供 verify / draft_extend 需要的 query 结构 - - -## 3. 先看 KV cache 的两层存储 - -理解 `kv_indptr` / `kv_indices` 之前,必须先看清 SGLang 的 KV cache 存储不是“一块连续上下文”,而是两层映射: - -- `ReqToTokenPool` -- `TokenToKVPool` - -### 3.1 `ReqToTokenPool` - -文件: - -- `sglang/python/sglang/srt/mem_cache/memory_pool.py` - -核心张量: - -- `req_to_token` - -shape: - -- `[req_pool_size, max_context_len]` -- dtype 通常是 `int32` - -语义: - -- 行:一个 request slot -- 列:这个 request 的逻辑 token 位置 -- 值:该位置对应的 **物理 KV slot id** - -也就是说,`req_to_token` 不是存 K/V 本身,而是存: - -- `request 的第 i 个 token,实际写到了 token_to_kv_pool 的哪个 slot` - -可以理解成: - -```text -req_to_token[req_pool_idx, token_pos] = physical_kv_slot -``` - -### 3.2 `TokenToKVPool` - -它是真正存物理 K/V 的地方。 - -根据注意力形式不同,常见有两类: - -- `MHATokenToKVPool` -- `MLATokenToKVPool` - -### 3.3 MHA KV cache 形状 - -文件: - -- `sglang/python/sglang/srt/mem_cache/memory_pool.py` - -MHA 下,每层通常有两块 buffer: - -- `k_buffer[layer]` -- `v_buffer[layer]` - -shape: - -- `k_buffer[layer]`: `[(size + page_size), num_kv_heads, head_dim]` -- `v_buffer[layer]`: `[(size + page_size), num_kv_heads, v_head_dim]` - -这里的第一维就是 **物理 token slot**。 - -也就是说: - -- `loc = 12345` -- `k_buffer[layer][12345]` -- `v_buffer[layer][12345]` - -就是这个 token 在该层的物理 KV 存储位置。 - -额外的 `+ page_size` 是为了预留 padding / dummy 写入空间。源码里有一句很关键: - -- padded slot 0 用于 padded token 的 dummy output write - -所以它不是严格只分配 `size` 个可见 token 位置,而是多留了一点缓冲。 - -### 3.4 MLA KV cache 形状 - -MLA 下不是单独一块 K、一块 V,而是一个合并后的 latent KV buffer。 - -shape: - -- `kv_buffer[layer]`: `[(size + page_size), 1, kv_cache_dim]` - -其中: - -- `kv_cache_dim = kv_lora_rank + qk_rope_head_dim` - - 对 DeepSeek MLA,通常就是 latent KV 部分加 rope 部分 - -这意味着: - -- MHA:一个 slot 对应 `K` 和 `V` -- MLA:一个 slot 对应一个融合后的 latent cache 向量 - -对 DeepSeek MLA,常见理解方式是: - -- 前半段:`kv_a` / latent KV -- 后半段:`k_pe` / rope 部分 - - -## 4. `req_to_token` 和 `out_cache_loc` 的关系 - -调度器在每轮 forward 前,会先给新 token 分配物理 slot,得到: - -- `out_cache_loc` - -shape: - -- `[num_new_tokens]` - -语义: - -- 本轮所有新增 token 应该写到哪些物理 KV slot - -然后再把这些 slot 填回 `req_to_token` 的对应位置。 - -可以画成: - -```mermaid -flowchart TD - A[request A / req_pool_idx=7] - B[request B / req_pool_idx=9] - C[out_cache_loc = new physical slots] - D[assign_req_to_token_pool] - E[ReqToTokenPool.req_to_token] - F[TokenToKVPool buffers] - - A --> C - B --> C - C --> D --> E - E --> F -``` - -本质上: - -- `out_cache_loc` 决定“新 token 写哪里” -- `req_to_token` 记录“逻辑位置到物理 slot 的长期映射” - - -## 5. `ForwardMetadata` 核心字段总览 - -下面重点解释这些字段: - -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `kv_last_page_len` -- `max_q_len` -- `max_kv_len` -- `page_table` -- `kv_lens` - -### 5.1 一个总表 - -| 字段 | 常见 shape | 主要用于 | 一句话语义 | -|------|------------|----------|------------| -| `kv_indptr` | `[bs + 1]` | MLA | KV flatten 后每个 request 的段边界 | -| `kv_indices` | `[sum(kv_lens)]` | MLA | flatten 后每个 KV token 对应的物理 slot | -| `qo_indptr` | `[bs + 1]` | MLA / speculative | Q flatten 后每个 request 的段边界 | -| `kv_last_page_len` | `[bs]` | MLA paged kernel | 每个 request 最后一个 page 里有多少有效 token | -| `max_q_len` | `int` | 所有 attention kernel | batch 内单个 request 的最大 query 长度 | -| `max_kv_len` | `int` or `None` | extend / prefill | batch 内单个 request 的最大 KV 长度 | -| `page_table` | `[bs, max_pages]` | 非 MLA | request -> page id 的二维表 | -| `kv_lens` | `[bs]` | 非 MLA | 每个 request 的 KV 长度 | - - -## 6. `kv_indptr` 是什么 - -### 6.1 语义 - -`kv_indptr` 是一个 CSR 风格的前缀和数组。 - -shape: - -- `[bs + 1]` - -语义: - -- 第 `i` 个 request 的 KV 段,在 `kv_indices` 中的范围是: - - `[kv_indptr[i], kv_indptr[i + 1])` - -所以它不是“KV 长度本身”,而是: - -- `flatten 之后每段的起止边界` - -### 6.2 它通常怎么构造 - -典型构造方式: - -```text -kv_indptr[0] = 0 -kv_indptr[1:] = cumsum(kv_lens) -``` - -其中: - -- decode 下,`kv_lens` 往往就是 `seq_lens` -- target_verify 下,MLA 常是 `seq_lens + draft_token_num` -- draft_extend 下,可能来自 speculative 专门生成的 prefill 参数 - -### 6.3 例子 - -假设 batch 里有两个 request: - -- request A 的 KV 长度 = 5 -- request B 的 KV 长度 = 3 - -那么: - -```text -kv_lens = [5, 3] -kv_indptr = [0, 5, 8] -``` - -表示: - -- request A 对应 `kv_indices[0:5]` -- request B 对应 `kv_indices[5:8]` - - -## 7. `kv_indices` 是什么 - -### 7.1 语义 - -`kv_indices` 是一个 flatten 后的一维数组。 - -shape: - -- `[sum(kv_lens)]` - -语义: - -- 它的每个元素都是 **物理 KV slot id** -- 这些 slot id 来自 `req_to_token` - -换句话说: - -- `kv_indices` 是“这一步 attention 真正要访问哪些物理 KV token” - -### 7.2 它和 `req_to_token` 的关系 - -`create_flashinfer_kv_indices_triton(...)` 会根据: - -- `req_pool_indices` -- `req_to_token` -- `kv_lens` -- `kv_indptr` - -把每个 request 对应的那一段 `req_to_token[row, :kv_len]` -抽出来,拼成一个一维的 `kv_indices`。 - -### 7.3 例子 - -假设: - -- `req_pool_indices = [7, 9]` -- `req_to_token[7, 0:5] = [100, 101, 102, 103, 120]` -- `req_to_token[9, 0:3] = [200, 201, 220]` - -那么: - -```text -kv_indptr = [0, 5, 8] -kv_indices = [100, 101, 102, 103, 120, 200, 201, 220] -``` - -这就表示: - -- request A 的 attention 访问物理 slot `100,101,102,103,120` -- request B 的 attention 访问物理 slot `200,201,220` - -可以把它理解为: - -```mermaid -flowchart LR - A["req_to_token[row=7] = [100,101,102,103,120,...]"] - B["req_to_token[row=9] = [200,201,220,...]"] - C["kv_indices = [100,101,102,103,120,200,201,220]"] - - A --> C - B --> C -``` - - -## 8. `qo_indptr` 是什么 - -### 8.1 语义 - -`qo_indptr` 和 `kv_indptr` 是对称的,但它描述的是 **Q / output 侧**。 - -shape: - -- `[bs + 1]` - -语义: - -- 第 `i` 个 request 的 query 段,在 flatten 后的 Q 张量中的范围是: - - `[qo_indptr[i], qo_indptr[i + 1])` - -### 8.2 为什么它很重要 - -attention backend 经常把 batch 里的 query token flatten 成一个二维/三维张量去跑 kernel。 - -这时 kernel 需要知道: - -- 哪些 query 属于 request A -- 哪些 query 属于 request B - -`qo_indptr` 就是这份分段说明书。 - -### 8.3 不同模式下的典型含义 - -- decode - - 每个 request 只有 1 个 query - - 所以常见是 `[0, 1, 2, ..., bs]` -- 普通 extend / prefill - - 每个 request 的 query 数就是 `extend_seq_lens[i]` - - 所以通常是 `cumsum(extend_seq_lens)` -- target_verify - - 每个 request 通常有 `draft_token_num` 个 query - - 所以常是 `[0, d, 2d, 3d, ...]` -- draft_extend - - 每个 request 的 query 数可能不同 - - 常来自 `accept_length` 或 `extend_seq_lens` - -### 8.4 例子 - -假设有两个 request: - -- request A 本轮新增 query = 3 -- request B 本轮新增 query = 2 - -那么: - -```text -extend_seq_lens = [3, 2] -qo_indptr = [0, 3, 5] -``` - -表示: - -- request A 的 query 是 flatten Q 中的 `[0:3]` -- request B 的 query 是 flatten Q 中的 `[3:5]` - - -## 9. `kv_last_page_len` 是什么 - -### 9.1 语义 - -这是分页 KV cache 下很重要的一个辅助量。 - -shape: - -- `[bs]` - -语义: - -- 每个 request 的最后一个 page 里,有多少个有效 token - -因为 paged KV cache 不是要求每个 request 的长度都刚好是 `page_size` 的整数倍,所以最后一个 page 往往只有一部分有效。 - -### 9.2 例子 - -假设 `page_size = 4`: - -- request A 的 KV 长度 = 5 -- request B 的 KV 长度 = 3 - -那么: - -- request A 有 2 个 page,最后一个 page 有 1 个有效 token -- request B 有 1 个 page,最后一个 page 有 3 个有效 token - -对应: - -```text -kv_last_page_len = [1, 3] -``` - - -## 10. `max_q_len` 和 `max_kv_len` - -### 10.1 `max_q_len` - -语义: - -- batch 内单个 request 的最大 query 长度 - -常见来源: - -- decode: `1` -- 普通 extend: `max(extend_seq_lens)` -- target_verify: `draft_token_num` -- draft_extend: 常是 `max(extend_seq_lens)` 或 `max(accept_length)` - -shape: - -- Python `int` - -作用: - -- kernel 需要知道 batch 内最大 query 段长度,来决定 tile / workspace / pad 方式 - -### 10.2 `max_kv_len` - -语义: - -- batch 内单个 request 的最大 KV 长度 - -常见来源: - -- 普通 extend / prefill:`max(seq_lens)` -- 某些 decode / verify MLA 路径里可能不单独存,或者设成 `None` - -shape: - -- Python `int` 或 `None` - - -## 11. `page_table` 和 `kv_lens` - -这两个字段更偏 **非 MLA** 路径,是 `kv_indptr/kv_indices` 的 page 化替代表示。 - -### 11.1 `page_table` - -shape: - -- `[bs, max_num_pages_per_request]` - -语义: - -- 每一行对应一个 request -- 每个元素是一个 physical page id - -它不是 token-level 的 flatten 索引,而是 page-level 的二维映射。 - -### 11.2 `kv_lens` - -shape: - -- `[bs]` - -语义: - -- 每个 request 当前 KV 长度 - -kernel 会结合: - -- `page_table` -- `kv_lens` -- `page_size` - -来知道每个 request 该读哪些 page、最后一页有多少有效 token。 - -### 11.3 这组字段是不是只给 MLA 用 - -不是。 - -`ForwardMetadata` 更准确地说是一个: - -- **统一容器** - -它里面同时装了: - -- MLA 常用的 metadata -- MHA 常用的 metadata -- 两边都可能用到的通用字段 - -可以粗略分成三类: - -| 类别 | 字段 | -|------|------| -| 更偏 MLA | `kv_indptr`, `kv_indices`, `qo_indptr`, `kv_last_page_len` | -| 更偏 MHA | `page_table`, `kv_lens`, `pa_metadata_*` | -| 通用 | `max_q_len`, `max_kv_len` | - -也就是说: - -- **不是只有 MLA 才会创建 `ForwardMetadata`** -- 而是 **MLA 和 MHA 共用这个 dataclass** -- 只是不同 kernel 最终只消费其中的一部分字段 - -### 11.4 MHA 的 metadata 代码在哪里 - -如果想看 `sgl_attn_backend.py` 里 **MHA 真正的 metadata 路径**,主要看这几段: - -- `_init_decode_mha()` -- `_init_extend_mha()` -- `_build_pa_metadata_for_decode()` -- `_build_pa_metadata_for_prefill()` - -含义可以概括成: - -- `decode` - - 优先看 `page_table`, `kv_lens` - - 如果走 `pa_persistent_fwd`,再看 `pa_metadata_*` -- `extend / prefill` - - 主要看 `max_q_len`, `max_kv_len` - - page 化路径下也会继续依赖 `page_table`, `kv_lens` - -换句话说: - -- **MLA 更像 “token-level flattened 索引驱动”** -- **MHA 更像 “page-table / context-len 驱动”** - -### 11.5 为什么 MHA 通常不需要 `kv_last_page_len` - -这个问题最容易和 MLA 搞混。 - -核心原因是: - -- MHA 在这个 backend 里通常走的是: - - `page_table + kv_lens` - - 或者 `pa_metadata_*` -- MLA 则更依赖: - - `kv_indptr + kv_indices + kv_last_page_len` - -对 MHA 来说,kernel 经常直接拿到: - -- 每个 request 的 `context_lens` -- 每个 request 对应哪些 page(`page_table`) -- 固定的 `page_size` - -于是: - -- 最后一页有多少有效 token -- 可以由 `context_lens % page_size` -- 或更高层 page metadata 直接推出来 - -所以 MHA 不一定需要把: - -- “最后一个 page 的有效长度” - -单独存成 `kv_last_page_len`。 - -而 MLA 的 paged kernel 实现更偏 token-flatten / ragged 索引驱动,显式传: - -- `kv_last_page_len` - -会更直接、更方便。 - -### 11.6 为什么 MHA 通常不需要 `qo_indptr` - -`qo_indptr` 的本质是: - -- flatten 后 query 段的边界表 - -它在 MLA 里很重要,因为 MLA kernel 经常直接消费: - -- ragged / flatten 的 Q 段 -- 对应的 KV flatten 段 - -而 MHA 在这个 plugin 里常见有两类路径: - -#### 路径一:decode 的 `pa_fwd_asm` / `pa_persistent_fwd` - -这类 kernel 更偏: - -- `block_tables = page_table` -- `context_lens = kv_lens` - -decode 下每个 request 本来就只有 1 个 query,所以 query 分段是隐含的: - -- batch 第 0 个 query 属于 request 0 -- batch 第 1 个 query 属于 request 1 - -这时单独维护 `qo_indptr` 不是必须的。 - -#### 路径二:extend 的 `flash_attn_varlen_func` - -这条路在当前 plugin 里更依赖: - -- 显式传入的 `q`, `k`, `v` -- `max_q_len`, `max_kv_len` -- 以及运行时构出来的 `cu_seqlens_q` - -这里 query 的分段信息已经由: - -- 输入张量本身 -- `cu_seqlens_q` - -表达出来了,所以 `qo_indptr` 也不是核心字段。 - -因此可以把它理解成: - -- **MLA 喜欢把 Q 段边界显式放进 metadata** -- **MHA 更常把 Q 段边界隐含在输入张量和专用 kernel 参数里** - -### 11.7 为什么 MHA 通常不需要 `kv_indptr + kv_indices` - -`kv_indptr + kv_indices` 的组合,本质上是在表达: - -- “把所有 request 的 KV token 拉平成一条长数组之后,每个 request 的 KV 段从哪里开始,到哪里结束” - -这是一种非常适合: - -- ragged token-level attention -- MLA flatten KV 访问 - -的表示法。 - -但 MHA 在 paged cache 下经常不需要把 KV 先 flatten 成 token 列表。 - -因为它可以直接用: - -- `page_table` -- `kv_lens` - -来表达同一件事: - -- 第 `i` 个 request 对应哪些 page -- 这些 page 中实际有多少 token 是有效的 - -可以类比为: - -- `kv_indptr + kv_indices` - - 是 token-level 的“稀疏展开形式” -- `page_table + kv_lens` - - 是 page-level 的“块索引形式” - -两者本质都在回答: - -- “本轮 attention 应该读哪些 KV” - -只是表达层次不同。 - -所以不是说 MHA **完全不能** 用 `kv_indptr + kv_indices`, -而是: - -- 在当前 backend 的主实现里,MHA 的更自然表示是 `page_table + kv_lens` -- `kv_indptr + kv_indices` 在 MHA 路径里通常不是主角 - -### 11.8 一个简化对照表 - -| 场景 | 更核心的 metadata | -|------|-------------------| -| MLA decode | `kv_indptr`, `kv_indices`, `qo_indptr`, `kv_last_page_len` | -| MLA extend | `kv_indptr`, `kv_indices`, `qo_indptr`, `max_q_len`, `max_kv_len` | -| MHA decode | `page_table`, `kv_lens`, `pa_metadata_*` | -| MHA extend | `max_q_len`, `max_kv_len`,以及必要时的 `page_table`, `kv_lens` | - - -## 12. 三个最重要的例子 - -### 12.1 例子一:普通 decode - -假设: - -- `bs = 2` -- `seq_lens = [5, 3]` -- `req_pool_indices = [7, 9]` -- `page_size = 4` - -并且: - -- `req_to_token[7, 0:5] = [100, 101, 102, 103, 120]` -- `req_to_token[9, 0:3] = [200, 201, 220]` - -那么: - -```text -kv_indptr = [0, 5, 8] -kv_indices = [100, 101, 102, 103, 120, 200, 201, 220] -qo_indptr = [0, 1, 2] -kv_last_page_len= [1, 3] -max_q_len = 1 -``` - -含义: - -- 每个 request 只解一个 token -- query 一共 2 个 -- 但每个 query 需要看到自己已有的完整上下文 KV - -### 12.2 例子二:普通 extend / prefill - -假设: - -- request A:prefix 长度 5,本轮 extend 3 个 token,总长 8 -- request B:prefix 长度 7,本轮 extend 2 个 token,总长 9 - -则: - -```text -extend_prefix_lens = [5, 7] -extend_seq_lens = [3, 2] -seq_lens = [8, 9] -qo_indptr = [0, 3, 5] -max_q_len = 3 -max_kv_len = 9 -``` - -理解方式: - -- Q 侧看的是“本轮新增的 3/2 个 token” -- KV 侧看的是“整个请求当前总长度 8/9” - -也就是说: - -- `qo_indptr` 由本轮新增 query 决定 -- `kv_indptr / max_kv_len` 由总上下文长度决定 - -### 12.3 例子三:`TARGET_VERIFY` - -假设: - -- `bs = 2` -- `seq_lens = [5, 3]` -- `draft_token_num = 4` - -那么对每个 request: - -- 本轮要验证 4 个 draft token -- 但每个 query 能看到的 KV 长度不是原始 `seq_lens` -- 而是 `seq_lens + draft_token_num` - -于是: - -```text -qo_indptr = [0, 4, 8] -kv_lens = [9, 7] -kv_indptr = [0, 9, 16] -max_q_len = 4 -``` - -这个例子非常重要,因为它说明: - -- verify 不是普通 decode -- 也不是普通 extend -- 它会同时改变 Q 的分段和 KV 的可见长度 - -这也是为什么 speculative path 不能简单复用普通 extend metadata。 - - -## 13. 一个完整的“逻辑位置 -> 物理 KV”例子 - -假设 request A 当前已经有 5 个 token: - -```text -req_pool_idx = 7 -req_to_token[7, 0:5] = [100, 101, 102, 103, 120] -``` - -这表示: - -- 逻辑 token 0 -> physical slot 100 -- 逻辑 token 1 -> physical slot 101 -- 逻辑 token 2 -> physical slot 102 -- 逻辑 token 3 -> physical slot 103 -- 逻辑 token 4 -> physical slot 120 - -如果本轮又新分配了两个 slot: - -```text -out_cache_loc = [130, 131] -``` - -并把它们写回 request A 的逻辑位置 5、6: - -```text -req_to_token[7, 5] = 130 -req_to_token[7, 6] = 131 -``` - -那么 request A 的完整逻辑到物理映射就变成: - -```text -[100, 101, 102, 103, 120, 130, 131] -``` - -之后 attention metadata 只要知道: - -- `req_pool_idx = 7` -- `kv_len = 7` - -就能通过 `req_to_token` 自动构造出: - -```text -kv_indices = [100, 101, 102, 103, 120, 130, 131] -``` - - -## 14. 可以如何快速判断一个字段该不该看 - -如果你在 debug `sgl_attn_backend.py`,可以用下面这个经验法则: - -- 想知道“这轮每个 request 有多少 query” - - 看 `qo_indptr`, `max_q_len`, `extend_seq_lens` -- 想知道“这轮每个 request 的 KV 能看到多长” - - 看 `kv_indptr`, `kv_last_page_len`, `max_kv_len`, `kv_lens` -- 想知道“这些 KV 实际在 cache 里是哪几个 slot” - - 看 `kv_indices` -- 想知道“request 的逻辑 token 位置和物理 slot 怎么对应” - - 看 `req_to_token_pool.req_to_token` -- 想知道“本轮新 token 会写去哪里” - - 看 `out_cache_loc` - - -## 15. 最后总结 - -如果只记三句话: - -1. `req_to_token` 是 **逻辑 token 位置 -> 物理 KV slot** 的长期映射表。 -2. `kv_indptr + kv_indices` 是把这张长期映射表裁成“本轮 attention 真正要访问的 KV 列表”。 -3. `qo_indptr` 是 query 侧的分段表,告诉 kernel flatten 后哪些 query 属于哪个 request。 - -所以: - -- `scheduler` 决定 batch 语义 -- `req_to_token / out_cache_loc` 决定 cache 物理布局 -- `ForwardMetadata` 把二者翻译成 kernel 真正能消费的索引 - -这就是 `sgl_attn_backend.py` 里这些字段的核心意义。 diff --git a/work_log/MTP/2026-04-08-sglang-kv-cache-storage-guide.md b/work_log/MTP/2026-04-08-sglang-kv-cache-storage-guide.md deleted file mode 100644 index 3b7bb3978..000000000 --- a/work_log/MTP/2026-04-08-sglang-kv-cache-storage-guide.md +++ /dev/null @@ -1,692 +0,0 @@ -# SGLang KV Cache Storage Guide - -## 文档目的 - -这篇文档专门解释: - -- SGLang 里 KV cache 是怎么存的 -- MHA 和 MLA 的存储结构有什么区别 -- `ReqToTokenPool`、allocator、`TokenToKVPool` 分别负责什么 -- 常见张量 shape 是什么 -- 逻辑 token 是怎么映射到物理 KV slot 的 - -本文聚焦 **SGLang 本身的 KV cache 存储**,不展开 vLLM 或 ATOM native 的实现差异。 - - -## 一句话理解 - -SGLang 的 KV cache 不是“每个 request 一块连续显存”。 - -它更像一个两级系统: - -1. `ReqToTokenPool` - - 维护逻辑映射 - - 回答“这个 request 的第 t 个 token 存在哪个 slot” -2. `TokenToKVPool` - - 维护物理存储 - - 真正把 K/V 或 MLA latent 写进 GPU buffer - -可以先记住这个核心公式: - -```text -req_to_token[req_pool_idx, token_pos] = physical_slot -``` - - -## 1. 先统一几个概念 - -### 1.1 request slot - -SGLang 不直接拿 request id 当数组下标。 - -它会先从 `ReqToTokenPool` 给每个活跃 request 分一个: - -- `req_pool_idx` - -这就是这个 request 在 `req_to_token` 表里的“行号”。 - -### 1.2 token slot - -一个 token slot 表示: - -- 某一层 KV cache 里,一个 token 对应的一行物理存储位置 - -它通常是一个全局整数,例如: - -- `100` -- `101` -- `2025` - -### 1.3 page - -当 `page_size > 1` 时,slot 会按 page 分组。 - -可以理解成: - -```text -一个 page = page_size 个连续 token slot -``` - -于是: - -```text -slot = page_id * page_size + page_offset -``` - -### 1.4 `out_cache_loc` - -`out_cache_loc` 是本轮新分配出来的物理 slot 列表。 - -shape: - -- extend / prefill:`[extend_num_tokens]` -- decode:通常是 `[bs * token_per_req]` - -它表示: - -- 本轮新 token 应该写到 KV cache 的哪些物理位置 - - -## 2. 总体架构 - -可以把 SGLang 的 KV cache 存储看成三层: - -- request 级逻辑层 -- token/page 分配层 -- 物理存储层 - -```mermaid -flowchart LR - A[Request] - B[req_pool_idx] - C[ReqToTokenPool.req_to_token] - D[out_cache_loc / allocator] - E[TokenToKVPool] - F[Physical KV buffers] - - A --> B - B --> C - D --> C - D --> E --> F - C --> F -``` - -更具体一点: - -- `ReqToTokenPool` - - 存“逻辑 token 位置 -> 物理 slot” -- allocator - - 决定这一步还能分到哪些新 slot -- `TokenToKVPool` - - 存每层真实的 K/V 或 MLA latent - - -## 3. 第一层:`ReqToTokenPool` - -文件: - -- `sglang/python/sglang/srt/mem_cache/memory_pool.py` - -核心类: - -- `ReqToTokenPool` - -核心张量: - -- `req_to_token` - -shape: - -- `[req_pool_size, max_context_len]` - -dtype: - -- `int32` - -语义: - -- 第 0 维:request slot,也就是 `req_pool_idx` -- 第 1 维:该 request 的逻辑 token 位置 -- 元素值:这个逻辑 token 对应的物理 KV slot - -也就是: - -```text -req_to_token[req_pool_idx, token_pos] = physical_slot -``` - -### 3.1 例子 - -假设 request A 被分到: - -- `req_pool_idx = 7` - -并且当前已经有 5 个 token: - -```text -req_to_token[7, 0:5] = [100, 101, 102, 103, 120] -``` - -那么它的逻辑到物理映射就是: - -- token 0 -> slot 100 -- token 1 -> slot 101 -- token 2 -> slot 102 -- token 3 -> slot 103 -- token 4 -> slot 120 - -注意: - -- 这里的 slot 不要求连续 -- 因为分页分配、复用、evict 都可能让物理位置不连续 - - -## 4. 第二层:allocator - -allocator 负责: - -- 从可用的 KV 空间中分配新的 slot 或 page -- 返回 `out_cache_loc` -- 再把结果写回 `req_to_token` - -SGLang 里常见有两类 allocator: - -- `TokenToKVPoolAllocator` - - `page_size = 1` - - 更像 token 粒度的平铺分配 -- `PagedTokenToKVPoolAllocator` - - `page_size > 1` - - 更像 page 粒度分配 - -相关文件: - -- `sglang/python/sglang/srt/mem_cache/allocator.py` -- `sglang/python/sglang/srt/mem_cache/common.py` - -### 4.1 `page_size = 1` - -这时 allocator 的视角非常简单: - -- 一个 free slot 就是一个 free token position - -分配出来的 `out_cache_loc` 可以直接看成: - -- 一串 token slot id - -源码里还有一个关键细节: - -- slot `0` 被保留给 padded token / dummy write - -所以真正可分配的 slot 常从 `1` 开始。 - -### 4.2 `page_size > 1` - -这时 allocator 虽然内部按 page 管理, -但对上层仍然返回: - -- token-level 的 `out_cache_loc` - -也就是说,上层最终看到的还是: - -- 这一步每个新 token 具体写到哪个 slot - -只是这些 slot 是由 page allocator 算出来的。 - -### 4.3 extend 时 allocator 做了什么 - -`alloc_for_extend()` 的语义可以概括成: - -1. 先给 request 分配 `req_pool_idx` -2. 再根据 prefix 长度和目标 seq 长度,分配这一步新增 token 的物理 slot -3. 生成 `out_cache_loc` -4. 把这些新 slot 写回 `req_to_token` - -所以: - -- `out_cache_loc` 是“这一步新 token 的写入位置” -- `req_to_token` 是“整个 request 的长期索引表” - -### 4.4 decode 时 allocator 做了什么 - -decode 最常见是每个 request 增加 1 个 token。 - -这时: - -- allocator 为每个 request 分 1 个新 slot -- `out_cache_loc` 的长度通常就是 batch size -- 然后把这个新 slot 写到 `req_to_token[req_pool_idx, 当前 seq_len]` - - -## 5. 第三层:物理 KV 存储 - -这层才是真正的大显存 buffer。 - -SGLang 里和 attention 相关的主要有: - -- `MHATokenToKVPool` -- `MLATokenToKVPool` - -两者最大的区别在于: - -- MHA 存 K 和 V 两份 buffer -- MLA 存一份 packed latent buffer - - -## 6. MHA 的 KV cache 存储 - -文件: - -- `sglang/python/sglang/srt/mem_cache/memory_pool.py` - -核心类: - -- `MHATokenToKVPool` - -### 6.1 核心 buffer - -每层有两份物理 buffer: - -- `k_buffer[layer]` -- `v_buffer[layer]` - -shape: - -- `k_buffer[layer]`: `[(size + page_size), num_kv_heads, head_dim]` -- `v_buffer[layer]`: `[(size + page_size), num_kv_heads, v_head_dim]` - -这里: - -- 第 0 维是物理 slot -- 第 1 维是 KV heads -- 第 2 维是每个 head 的维度 - -`size + page_size` 的原因是: - -- 除了正常容量,还预留了 padding / dummy 写入空间 - -### 6.2 怎么写入 - -写入接口通常是: - -- `set_kv_buffer(layer, loc, cache_k, cache_v, ...)` - -其中: - -- `loc.shape = [num_tokens]` -- `cache_k.shape = [num_tokens, num_kv_heads, head_dim]` -- `cache_v.shape = [num_tokens, num_kv_heads, v_head_dim]` - -语义就是: - -```text -k_buffer[layer][loc[i]] = cache_k[i] -v_buffer[layer][loc[i]] = cache_v[i] -``` - -### 6.3 怎么读 - -读取接口通常是: - -- `get_key_buffer(layer_id)` -- `get_value_buffer(layer_id)` -- `get_kv_buffer(layer_id)` - -attention backend 会根据 `req_to_token` 算出的 slot, -去这些 buffer 里 gather 对应位置。 - - -## 7. MLA 的 KV cache 存储 - -文件: - -- `sglang/python/sglang/srt/mem_cache/memory_pool.py` - -核心类: - -- `MLATokenToKVPool` - -### 7.1 核心 buffer - -MLA 下,每层通常只有一份主 buffer: - -- `kv_buffer[layer]` - -shape: - -- `[(size + page_size), 1, kv_cache_dim]` - -其中: - -```text -kv_cache_dim = kv_lora_rank + qk_rope_head_dim -``` - -这表示: - -- 每个物理 slot 存的是一段 packed latent KV -- 不是标准 MHA 意义上的分离 K / V - -### 7.2 逻辑拆分 - -对于 DeepSeek MLA,通常可以把这段 packed buffer 理解成: - -- 前半段:`kv_a` / latent KV -- 后半段:`k_pe` / rope 相关部分 - -也就是说,一个 slot 里实际上装的是: - -```text -[cache_k_nope | cache_k_rope] -``` - -### 7.3 写入接口 - -MLA 常见有两种写法: - -- `set_kv_buffer(...)` - - 直接把 packed cache 写进去 -- `set_mla_kv_buffer(layer, loc, cache_k_nope, cache_k_rope)` - - 分别传 latent 部分和 rope 部分,由底层 helper 拼到一起 - -典型输入 shape: - -- `loc`: `[num_tokens]` -- `cache_k_nope`: `[num_tokens, 1, kv_lora_rank]` -- `cache_k_rope`: `[num_tokens, 1, qk_rope_head_dim]` - -### 7.4 读取接口 - -MLA 也有专门的读取接口: - -- `get_mla_kv_buffer(layer, loc, dst_dtype)` - -返回: - -- `cache_k_nope`: `[num_tokens, 1, kv_lora_rank]` -- `cache_k_rope`: `[num_tokens, 1, qk_rope_head_dim]` - -所以 MLA 的“读出来再用”其实也是把 packed storage 再拆回两部分。 - - -## 8. MHA 和 MLA shape 对照表 - -| 项目 | MHA | MLA | -|------|-----|-----| -| 主 buffer 数量 | 2 份:K / V | 1 份:packed latent | -| 每层物理 shape | K:`[slots, Hkv, Dk]` V:`[slots, Hkv, Dv]` | `[slots, 1, kv_lora_rank + qk_rope_head_dim]` | -| 第 0 维含义 | 物理 token slot | 物理 token slot | -| 典型写入接口 | `set_kv_buffer()` | `set_mla_kv_buffer()` | -| 典型读取接口 | `get_kv_buffer()` | `get_mla_kv_buffer()` | -| 逻辑视角 | 标准 K/V cache | latent KV + rope 部分 | - - -## 9. `out_cache_loc`、`req_to_token`、buffer 的关系 - -可以把一次写入过程画成: - -```mermaid -flowchart TD - A[allocator 分配新 slot] - B[out_cache_loc] - C[写回 req_to_token] - D[attention forward 产出 K/V 或 MLA latent] - E[set_kv_buffer / set_mla_kv_buffer] - F[物理 KV buffer] - - A --> B - B --> C - B --> E - D --> E --> F -``` - -这里有两个并行动作: - -- `out_cache_loc` 被写回 `req_to_token` -- 同时新算出来的 KV 被写进物理 buffer - -这样下一轮只要知道: - -- `req_pool_idx` -- 当前 `seq_len` - -就能通过 `req_to_token` 找到历史 token 对应的所有物理 slot。 - - -## 10. 例子一:非分页 MHA decode - -假设: - -- `page_size = 1` -- batch 有 2 个 request -- `req_pool_indices = [7, 9]` -- 当前 `seq_lens = [5, 3]` - -已有映射: - -```text -req_to_token[7, 0:5] = [100, 101, 102, 103, 120] -req_to_token[9, 0:3] = [200, 201, 220] -``` - -本轮 decode,每个 request 新增 1 个 token,allocator 返回: - -```text -out_cache_loc = [130, 221] -``` - -然后写回: - -```text -req_to_token[7, 5] = 130 -req_to_token[9, 3] = 221 -``` - -于是下一轮: - -- request A 的完整上下文 slot 是 `[100,101,102,103,120,130]` -- request B 的完整上下文 slot 是 `[200,201,220,221]` - -物理存储上则是: - -```text -k_buffer[layer][130] = new_k_for_A -v_buffer[layer][130] = new_v_for_A - -k_buffer[layer][221] = new_k_for_B -v_buffer[layer][221] = new_v_for_B -``` - - -## 11. 例子二:分页 MHA extend - -假设: - -- `page_size = 4` -- request A 的 prefix 长度 = 5 -- 本轮 extend 后总长度 = 8 - -也就是: - -- prefix token 已经占了 5 个逻辑位置 -- 本轮要再写 3 个 token - -假设它当前最后一个已用 slot 是: - -```text -last_loc = 120 -``` - -而这个 `120` 恰好在某个 page 的中间。 - -那么 allocator 在 `alloc_paged_token_slots_extend()` 里大概会做两件事: - -1. 先尽量把当前未满的最后一个 page 填满 -2. 如果还不够,再分配新 page - -可能得到: - -```text -out_cache_loc = [121, 122, 200] -``` - -这表示: - -- 前两个 token 继续写进原 page 的剩余位置 -- 第三个 token 写进新 page 的第一个 slot - -然后写回: - -```text -req_to_token[7, 5:8] = [121, 122, 200] -``` - -所以分页 allocator 的重点不是“返回 page id”,而是: - -- **最终依然返回 token-level slot ids** - - -## 12. 例子三:MLA 写入和读取 - -假设: - -- `kv_lora_rank = 512` -- `qk_rope_head_dim = 64` - -那么: - -```text -kv_cache_dim = 576 -``` - -对某层来说,MLA 物理 buffer 的 shape 可能是: - -```text -kv_buffer[layer].shape = [num_slots, 1, 576] -``` - -本轮有 2 个新 token: - -```text -loc = [130, 131] -cache_k_nope.shape = [2, 1, 512] -cache_k_rope.shape = [2, 1, 64] -``` - -调用: - -```text -set_mla_kv_buffer(layer, loc, cache_k_nope, cache_k_rope) -``` - -后,可以理解为: - -```text -kv_buffer[layer][130] = concat(cache_k_nope[0], cache_k_rope[0]) -kv_buffer[layer][131] = concat(cache_k_nope[1], cache_k_rope[1]) -``` - -后续 attention 需要读取历史 cache 时,再通过: - -```text -get_mla_kv_buffer(layer, loc=[100,101,130], dst_dtype=bf16) -``` - -拿回: - -- `cache_k_nope`: `[3, 1, 512]` -- `cache_k_rope`: `[3, 1, 64]` - - -## 13. 为什么 SGLang 要搞两层,而不是直接 request -> kv buffer - -因为推理服务不是静态 batch。 - -SGLang 的 request 会不断: - -- 加入 -- 完成 -- 被截断 -- 被 speculative verify / draft_extend 修改 -- 被分页 allocator 扩容 - -如果直接给每个 request 一块连续大 buffer: - -- 复用差 -- 容易碎片化 -- prefix cache / page cache 不好做 - -两层结构的好处是: - -- `ReqToTokenPool` - - 负责逻辑组织 -- `TokenToKVPool` - - 负责物理存储 - -这样: - -- request 的逻辑顺序可以变 -- 物理 slot 可以复用 -- page allocator 可以独立演化 -- MHA / MLA 只需要换底层 KV pool 的 shape,不用重写上层 request 索引系统 - - -## 14. 和 attention metadata 的关系 - -KV cache 存储本身只回答: - -- 数据放在哪里 - -attention metadata 还要回答: - -- 本轮到底读哪些 token -- 这些 token 该怎么分段 -- 对应哪个 request - -所以常见链路是: - -1. `req_to_token` - - 保存 request 的长期逻辑到物理映射 -2. `out_cache_loc` - - 保存本轮新 token 的新物理位置 -3. attention metadata - - 从 `req_to_token` 中抽出本轮真正要访问的那部分 slot - - 形成 `kv_indices` 或 `page_table` - -也就是说: - -- KV cache storage 是“数据库” -- attention metadata 是“查询结果” - - -## 15. 调试时最该先看什么 - -如果你在 debug SGLang KV cache,建议按这个顺序看: - -1. `req_pool_idx` - - 这个 request 映射到哪一行 -2. `req_to_token[row, :seq_len]` - - 当前逻辑 token 对应哪些物理 slot -3. `out_cache_loc` - - 本轮新 token 写到哪里 -4. `k_buffer / v_buffer` 或 `kv_buffer` - - 这些 slot 位置上实际存了什么 shape -5. attention metadata - - 例如 `kv_indices` / `page_table` - - 看本轮真正读的是不是你以为的那些 slot - - -## 16. 最后总结 - -如果只记 6 句话: - -1. `ReqToTokenPool.req_to_token` 是 SGLang KV cache 的逻辑索引总表。 -2. `out_cache_loc` 是本轮新 token 的物理写入位置。 -3. allocator 可能按 token 或 page 分配,但返回给上层的通常仍是 token-level slot。 -4. MHA 物理存储是两份: - - `k_buffer[layer]: [slots, Hkv, Dk]` - - `v_buffer[layer]: [slots, Hkv, Dv]` -5. MLA 物理存储是一份 packed buffer: - - `kv_buffer[layer]: [slots, 1, kv_lora_rank + qk_rope_head_dim]` -6. attention metadata 不是重复存储 KV cache,而是基于 `req_to_token` 再生成“本轮实际访问哪些 KV”的索引视图。 - -这就是 SGLang 里 KV cache 存储的核心结构。 diff --git a/work_log/MTP/2026-04-08-sglang-speculative-decoding-architecture.md b/work_log/MTP/2026-04-08-sglang-speculative-decoding-architecture.md deleted file mode 100644 index 26db38442..000000000 --- a/work_log/MTP/2026-04-08-sglang-speculative-decoding-architecture.md +++ /dev/null @@ -1,910 +0,0 @@ -# 2026-04-08 SGLang Speculative Decoding 架构笔记 - -## 文档目的 - -这份文档用于从 **SGLang 整体架构** 的角度梳理 speculative decoding -(推测解码)的实现方式,重点回答下面几个问题: - -- SGLang 在启动阶段如何决定是否进入 speculative decoding -- target model 和 draft model 是如何被构造和组织的 -- scheduler、worker、batch、attention backend 分别承担什么职责 -- EAGLE / EAGLE3 / NEXTN / STANDALONE / NGRAM 在 SGLang 里是怎样映射的 -- v1 与 v2(overlap/spec v2)在控制流和数据结构上有何差异 -- 遇到问题时,应该优先看哪些代码位置 - -本文会尽量从“系统设计”和“源码位置”两个维度同时展开,方便后续复盘。 - - -## 一句话总结 - -从架构上看,SGLang 的 speculative decoding 可以概括成: - -- **配置层** 先在 `ServerArgs` 中解析算法与 draft 参数 -- **模型配置层** 决定 draft model 应该映射成哪一个架构名 -- **调度层** 通过 `Scheduler` 把执行入口从普通 `TpModelWorker` 切到 - speculative orchestrator -- **worker 层** 维护 target worker 与 draft worker 的协作关系 -- **batch 层** 用 `ScheduleBatch -> ModelWorkerBatch -> ForwardBatch` 三层结构 - 将调度语义转为 GPU 执行语义 -- **attention/backend 层** 再按 `ForwardMode` 和 `SpecInput` 区分 - `decode / target_verify / draft_extend` - -也就是说,speculative decoding 在 SGLang 里不是“加一个 draft model”这么简单, -而是一整套跨: - -- 配置 -- 调度 -- 模型加载 -- batch 编排 -- attention metadata - -的系统设计。 - - -## 1. 启动入口:ServerArgs 如何决定 speculative decoding - -### 1.1 关键配置项 - -核心入口文件: - -- `sglang/python/sglang/srt/server_args.py` - -关键字段位置: - -- `speculative_algorithm`:约 `480` -- `speculative_draft_model_path`:约 `481` -- `speculative_num_steps`:约 `484` -- `speculative_num_draft_tokens`:约 `486` - -这些字段决定: - -- 用哪种 speculative 算法 -- draft model 从哪里加载 -- 每轮 draft 几步 -- 每轮最多提议多少 draft token - - -### 1.2 `NEXTN` 在 SGLang 里的真实含义 - -很多人第一次看会以为 `NEXTN` 是一条完全独立的 speculative runtime。 -实际上不是。 - -在: - -- `sglang/python/sglang/srt/server_args.py` -- `_handle_speculative_decoding()` 逻辑中 - -有一个关键规范化: - -- `NEXTN -> EAGLE` - -对应代码位置: - -- `server_args.py` 约 `2680-2681` - -这意味着: - -- 用户在命令行里写 `--speculative-algorithm NEXTN` -- 进入运行时后,SGLang 会把它归并到 `EAGLE` 这套 speculative worker 流程里 - -也就是说: - -- `NEXTN` 更像是“draft model 形态 / 语义” -- `EAGLE` 更像是“runtime orchestration 机制” - - -### 1.3 spec v2 与 overlap scheduler - -仍然是在: - -- `server_args.py` -- `_handle_speculative_decoding()` - -关键逻辑位置: - -- `2696-2716` - -SGLang 会做一件重要的系统级决策: - -- 如果 speculative 算法属于 `EAGLE / EAGLE3 / STANDALONE` -- 且环境变量 `SGLANG_ENABLE_SPEC_V2=True` -- 则开启 overlap schedule(即 spec v2) - -否则: - -- 会退回到不带 overlap 的传统路径(可以理解为 spec v1) - -同时还有一些额外约束: - -- spec v2 目前只支持 `topk = 1` -- 使用 speculative 时会关闭 mixed chunked prefill - - -### 1.4 DeepSeek / MTP 与 `speculative_draft_model_path` - -同一段逻辑里还有一个对 DeepSeek 很关键的行为: - -- 对 `DeepseekV3ForCausalLM`、`DeepseekV32ForCausalLM`、`GlmMoeDsaForCausalLM` - 等架构 -- 如果没有显式传 `speculative_draft_model_path` -- 会自动把它设成主模型路径 - -关键位置: - -- `server_args.py` 约 `2725-2748` - -这就是为什么日志里会有类似: - -- `DeepSeek MTP does not require setting speculative_draft_model_path.` - -的提示。 - -这说明 SGLang 把 DeepSeek MTP / NextN 看成是某种“和 target 模型强绑定”的 -draft 形态,而不是完全独立的小模型。 - - -## 2. 算法层:SpeculativeAlgorithm 与 worker 工厂 - -核心文件: - -- `sglang/python/sglang/srt/speculative/spec_info.py` - -### 2.1 算法枚举 - -关键枚举: - -- `SpeculativeAlgorithm` - -包含: - -- `EAGLE` -- `EAGLE3` -- `STANDALONE` -- `NGRAM` -- `NONE` - -关键位置: - -- `spec_info.py` 约 `15-23` - - -### 2.2 worker 工厂 - -最关键的方法: - -- `SpeculativeAlgorithm.create_worker()` - -关键位置: - -- `spec_info.py` 约 `52-105` - -这个函数负责把: - -- 算法类型 -- overlap 是否开启 -- multi-layer eagle 是否开启 - -映射成具体 worker 类。 - -典型映射关系: - -- `EAGLE + overlap` -> `EAGLEWorkerV2` -- `EAGLE + no overlap` -> `EAGLEWorker` -- `STANDALONE + overlap` -> `StandaloneWorkerV2` -- `STANDALONE + no overlap` -> `StandaloneWorker` -- `NGRAM` -> `NGRAMWorker` - - -### 2.3 什么叫 “supports_spec_v2” - -还有个很重要的方法: - -- `supports_spec_v2()` - -关键位置: - -- `spec_info.py` 约 `49-50` - -含义是: - -- 当前算法是否支持 overlap/spec v2 抽象 - -目前只有: - -- `EAGLE` -- `STANDALONE` - -对应为真。 - - -## 3. 调度层:Scheduler 如何把普通模型调度切成 speculative 调度 - -核心文件: - -- `sglang/python/sglang/srt/managers/scheduler.py` - -### 3.1 初始化顺序 - -关键位置: - -- `maybe_init_draft_worker()`:约 `527-554` -- `init_model_worker()`:约 `556-564` - -逻辑顺序是: - -1. 先建 `tp_worker` -2. 如果 speculative 开启,再建 `draft_worker` -3. 决定 `self.model_worker` 指向谁 - -代码语义: - -- 没开 speculative: - - `self.model_worker = self.tp_worker` -- 开了 speculative: - - `self.model_worker = self.draft_worker` - - -### 3.2 为什么 `self.model_worker = self.draft_worker` - -这里名字非常容易误导。 - -`scheduler.draft_worker` 并不一定是一个“纯 draft model worker”,它更像是: - -- speculative orchestrator - -例如: - -- `EAGLEWorker` -- `EAGLEWorkerV2` - -也就是说: - -- scheduler 并不是“把 target worker 替换掉了” -- 而是把执行入口切到了一个能同时协调 target + draft 的总控 worker - - -### 3.3 `run_batch()` 的差异 - -关键位置: - -- `scheduler.py` 约 `2360-2426` - -这里能看出 v1 与 v2 在 batch 抽象上的差异: - -- 开 overlap/spec v2 时: - - `worker_batch_or_batch = batch.get_model_worker_batch()` - - 下游主要处理 `ModelWorkerBatch` -- 非 overlap 的传统 speculative v1: - - 会直接把 `ScheduleBatch` 传给 `model_worker.forward_batch_generation()` - -这也是为什么你有时候会看到: - -- 有的 speculative worker 收的是 `ScheduleBatch` -- 有的 speculative worker 收的是 `ModelWorkerBatch` - -这不是 bug,而是新老抽象并存。 - - -## 4. Worker 层:target worker、draft worker 与 orchestrator 的关系 - -### 4.1 普通 target worker - -核心文件: - -- `sglang/python/sglang/srt/managers/tp_worker.py` - -`TpModelWorker` 是普通模型执行单元,负责: - -- 初始化 `ModelConfig` -- 初始化 `ModelRunner` -- 提供 `forward_batch_generation()` - -关键位置: - -- `_init_model_config()`:约 `320-336` -- `_init_model_runner()`:约 `338-358` -- `forward_batch_generation()`:约 `442+` - - -### 4.2 target 和 draft 的分流在哪里发生 - -在 `TpModelWorker._init_model_config()` 中: - -- 如果 `is_draft_worker=False`,用主模型路径 -- 如果 `is_draft_worker=True`,用 `speculative_draft_model_path` - -关键位置: - -- `tp_worker.py` 约 `323-336` - -这就是 target 和 draft 最底层模型配置分流的地方。 - - -### 4.3 `EAGLEWorker`(v1) - -核心文件: - -- `sglang/python/sglang/srt/speculative/eagle_worker.py` - -`EAGLEWorker` 的特点: - -- 自己继承自 `TpModelWorker` -- 运行时同时持有: - - target worker - - 自己这套 draft model runner - -其 `forward_batch_generation()` 的大致逻辑是: - -- 如果是 extend: - - 先 `forward_target_extend` - - 再 `forward_draft_extend` -- 如果是 decode: - - 先 `draft()` - - 再 `verify()` - - 再 `forward_draft_extend_after_decode()` - -关键位置: - -- `eagle_worker.py` 约 `278-337` - - -### 4.4 `EAGLEWorkerV2`(spec v2 / overlap) - -核心文件: - -- `sglang/python/sglang/srt/speculative/eagle_worker_v2.py` - -v2 与 v1 最大的结构差异是: - -- 外层 `EAGLEWorkerV2` 是 orchestrator -- 内层还有一个 `EagleDraftWorker` -- `EagleDraftWorker` 再内嵌一个真正的 draft `TpModelWorker` - -关键类: - -- `EagleDraftWorker`:约 `82` -- `EAGLEWorkerV2`:约 `607` - -这层设计的意义是: - -- 把 draft 逻辑进一步模块化 -- 更方便做 overlap 和独立的 draft graph / backend 管理 - - -### 4.5 `StandaloneWorkerV2` - -核心文件: - -- `sglang/python/sglang/srt/speculative/standalone_worker_v2.py` - -它和 `EAGLEWorkerV2` 的主要区别不是调度框架,而是: - -- draft model 不再共享 target 的 embedding / lm_head - -在源码里可以看到: - -- `StandaloneDraftWorker.init_lm_head()` 明确覆写为空实现 - -也就是: - -- standalone draft 用自己的一套 embedding/head -- 不走与 target 的共享逻辑 - - -## 5. 模型配置与 draft 架构改写 - -核心文件: - -- `sglang/python/sglang/srt/configs/model_config.py` - -### 5.1 `ModelConfig.from_server_args()` - -这是 target / draft `ModelConfig` 的统一入口。 - -关键位置: - -- `from_server_args()`:约 `238+` - - -### 5.2 `_config_draft_model()` - -最关键的方法: - -- `_config_draft_model()` - -关键位置: - -- `model_config.py` 约 `277-340` - -对于 DeepSeek: - -- 若 `is_draft_model=True` -- 且原始架构是 `DeepseekV3ForCausalLM` - -就会改写为: - -- `DeepseekV3ForCausalLMNextN` - -这是 draft 侧为什么会变成 NextN 壳子的核心原因。 - - -### 5.3 这和 ATOM plugin 的关系 - -这也是当前 `ATOM plugin` 只接管 target、没接管 draft 的根源: - -- ATOM external model package 只导出了 `DeepseekV3ForCausalLM` -- 并没有导出 `DeepseekV3ForCausalLMNextN` - -所以最终结果是: - -- target `DeepseekV3ForCausalLM` 被 external package 覆盖 -- draft `DeepseekV3ForCausalLMNextN` 仍走 upstream SGLang native - - -## 6. 模型实例化链路 - -### 6.1 `ModelRunner.load_model()` - -核心文件: - -- `sglang/python/sglang/srt/model_executor/model_runner.py` - -关键位置: - -- `load_model()`:约 `901-991` - -这里完成: - -- 构造 `LoadConfig` -- 选择 model loader -- 调 `loader.load_model(...)` - - -### 6.2 `ModelRunner._get_attention_backend()` - -关键位置: - -- `model_runner.py` 约 `1736-1746` - -这里会根据: - -- 是否是 draft worker -- 是否设置了 `speculative_draft_attention_backend` - -来决定 draft 用哪种 attention backend。 - -这是 speculative 与 attention backend 结合的一个重要入口。 - - -### 6.3 `_initialize_model()` - -核心文件: - -- `sglang/python/sglang/srt/model_loader/loader.py` - -关键位置: - -- `_initialize_model()`:约 `257-277` - -这是底层真正 `return model_class(**kwargs)` 的地方。 - -也就是说: - -- target 和 draft 在上层是两个不同 worker / model config -- 但底层最终都汇合到同一个模型实例化函数 - - -### 6.4 `get_model_architecture()` - -核心文件: - -- `sglang/python/sglang/srt/model_loader/utils.py` - -关键位置: - -- `get_model_architecture()`:约 `89-119` - -这个函数负责: - -- 看 `hf_config.architectures` -- 查 `ModelRegistry` -- 最终选出要实例化的 model class - -如果 external package 覆盖了同名架构,就会优先拿 external package 的类。 - - -## 7. 三层 batch 数据结构 - -核心文件: - -- `sglang/python/sglang/srt/model_executor/forward_batch_info.py` -- `sglang/python/sglang/srt/managers/schedule_batch.py` - -### 7.1 三层抽象 - -`forward_batch_info.py` 文件开头就写得很清楚: - -- `ScheduleBatch -> ModelWorkerBatch -> ForwardBatch` - -含义: - -- `ScheduleBatch` - - scheduler 侧高层调度状态 - - CPU 语义更强 -- `ModelWorkerBatch` - - 给 worker 的中间态 -- `ForwardBatch` - - 最接近 kernel / backend 执行的低层态 - - -### 7.2 `ForwardMode` - -关键位置: - -- `forward_batch_info.py` 约 `74-179` - -推测相关最重要的几个 mode: - -- `TARGET_VERIFY` -- `DRAFT_EXTEND` -- `DRAFT_EXTEND_V2` -- `DECODE` -- `EXTEND` - -这里有个容易踩坑的点: - -- `TARGET_VERIFY` 在 `is_extend()` 里返回真 - -所以如果 backend 只按 “decode vs extend” 粗暴分流,很容易把 verify 错当普通 extend。 - - -### 7.3 `ScheduleBatch.get_model_worker_batch()` - -核心文件: - -- `sglang/python/sglang/srt/managers/schedule_batch.py` - -关键位置: - -- `get_model_worker_batch()`:约 `2175-2228` - -这一步负责把 scheduler 层状态打包成 `ModelWorkerBatch`。 - -关键理解: - -- 对 `decode_or_idle()`,`extend_seq_lens` 会被设成 `None` -- 对其他 extend 类路径,`extend_seq_lens` 来自 `self.extend_lens` - -这也是后面 verify 路径里经常出现 `extend_seq_lens=None` 的背景。 - - -## 8. speculative 的核心数据结构:SpecInput - -核心文件: - -- `sglang/python/sglang/srt/speculative/spec_info.py` - -关键抽象: - -- `SpecInput` -- `SpecInputType` - -类型包括: - -- `EAGLE_DRAFT` -- `EAGLE_VERIFY` -- `NGRAM_VERIFY` - -也就是说,speculative 不只是“多传几个 tensor”,而是有一套专门的数据结构协议。 - - -### 8.1 DeepSeek / EAGLE 相关具体实现 - -主要文件: - -- `sglang/python/sglang/srt/speculative/eagle_info.py` -- `sglang/python/sglang/srt/speculative/eagle_info_v2.py` - -这些文件负责: - -- draft 输入构造 -- verify 输入构造 -- draft token / hidden state / custom mask / positions 等 speculative 元数据 - - -## 9. EAGLE v1 的主执行流程 - -核心文件: - -- `sglang/python/sglang/srt/speculative/eagle_worker.py` - -### 9.1 extend / prefill 阶段 - -关键位置: - -- `forward_batch_generation()`:约 `278-309` - -流程: - -1. target 先跑 extend / prefill -2. target 产出 hidden state -3. draft 用 target hidden state 再做 draft extend - - -### 9.2 decode 阶段 - -关键位置: - -- `forward_batch_generation()`:约 `310-337` - -流程: - -1. draft 先 propose -2. target 再 verify -3. draft 根据 verify 结果再 extend,为下一轮准备 - -这是一个典型的: - -- `draft -> target verify -> draft extend` - -链式协作过程。 - - -### 9.3 why target and draft share embed/head - -关键位置: - -- `eagle_worker.py` 约 `157-183` - -这里会显式调用: - -- `target_worker.model_runner.model.get_embed_and_head()` -- `draft_model_runner.model.set_embed_and_head(...)` - -说明: - -- upstream 的 EAGLE/NextN draft 设计默认依赖 target 的 embedding 和 lm_head - - -## 10. EAGLE v2 的主执行流程 - -核心文件: - -- `sglang/python/sglang/srt/speculative/eagle_worker_v2.py` - -### 10.1 prefill / extend - -关键位置: - -- `forward_batch_generation()`:约 `673-697` - -流程: - -1. target prefill -2. draft prefill -3. 返回 `next_draft_input` - - -### 10.2 decode / verify - -关键位置: - -- `forward_batch_generation()`:约 `698-722` -- `verify()`:约 `724-780` - -流程: - -1. `draft_worker.draft()` 生成 `EagleVerifyInput` -2. `verify()` 内部构造 verify forward batch -3. target 执行 verify 前向 -4. draft 再做 `_draft_extend_for_decode()` - - -### 10.3 spec v2 的一个核心特征 - -它不再直接围绕 `ScheduleBatch` 做所有 speculative 逻辑,而是更偏向: - -- `ModelWorkerBatch` -- `next_draft_input` -- `future_indices` -- overlap plan stream - -这也是它和 v1 最大的结构差异。 - - -## 11. attention backend 如何感知 speculative - -最典型的文件: - -- `sglang/python/sglang/srt/layers/attention/aiter_backend.py` - -### 11.1 `init_forward_metadata()` - -关键位置: - -- `aiter_backend.py` 约 `435+` - -这段代码是理解 speculative attention 路径最关键的入口之一。 - -它不是简单区分: - -- decode -- extend - -而是细分: - -- `decode_or_idle` -- `draft_extend` -- `target_verify` -- 普通 extend - - -### 11.2 `draft_extend` - -关键位置: - -- `aiter_backend.py` 约 `526-606` - -特点: - -- 通过 `spec_info.generate_attn_arg_prefill(...)` - 来生成 draft extend 所需的 attention 参数 -- 对 MLA 与非 MLA 路径分别处理 - - -### 11.3 `target_verify` - -关键位置: - -- `aiter_backend.py` 约 `607+` - -特点: - -- 不依赖普通 extend 的 `extend_seq_lens` -- 直接根据: - - `spec_info.draft_token_num` - - `forward_batch.seq_lens` - -构造 verify 所需的: - -- `qo_indptr` -- `kv_indptr` -- `kv_indices` - -这是后续排查 `ATOM plugin verify` 问题时最值得对照的一段。 - - -## 12. v1 与 v2 的差异总结 - -### 12.1 v1 - -特点: - -- 更偏 `ScheduleBatch` -- speculative worker 逻辑更集中在一个大类里 -- 以串行 orchestrate 为主 - - -### 12.2 v2 - -特点: - -- 依赖 overlap scheduler -- 更偏 `ModelWorkerBatch` -- 引入 `next_draft_input` -- 更明显地区分 draft worker 与 orchestrator worker -- 更容易做 plan stream / overlap - - -### 12.3 对调试的实际影响 - -如果你调试 speculative 问题,一定要先分清: - -- 当前是 v1 还是 v2 -- 当前 `model_worker.forward_batch_generation(...)` - 收到的是 `ScheduleBatch` 还是 `ModelWorkerBatch` - -否则很容易误判字段来源和生命周期。 - - -## 13. 一份推荐阅读顺序 - -如果之后需要重新从头理解 SGLang speculative decoding,建议按下面顺序读: - -1. `server_args.py` - - 看 speculative 参数、NEXTN 规范化、spec v2 开关 -2. `spec_info.py` - - 看算法枚举和 worker 工厂 -3. `scheduler.py` - - 看 `maybe_init_draft_worker()` / `init_model_worker()` / `run_batch()` -4. `tp_worker.py` - - 看 target 与 draft `ModelConfig` 的分流 -5. `model_config.py` - - 看 `_config_draft_model()` -6. `deepseek_nextn.py` - - 看 `DeepseekV3ForCausalLMNextN` 到底长什么样 -7. `eagle_worker.py` - - 看传统 speculative v1 主流程 -8. `eagle_worker_v2.py` - - 看 overlap/spec v2 主流程 -9. `schedule_batch.py` - - 看 `ScheduleBatch -> ModelWorkerBatch` -10. `forward_batch_info.py` - - 看 `ForwardMode` 和 `ForwardBatch` -11. `aiter_backend.py` - - 看 speculative attention metadata 怎么初始化 - - -## 14. 对 ATOM plugin 调试的启发 - -这份背景知识对 `ATOM + SGLang plugin` 的调试最直接的启发有三点: - -### 启发 1 - -不能只盯 model class,还要盯: - -- `ServerArgs` -- `ModelConfig` -- `Scheduler` -- `TpModelWorker` - -因为 draft/target 的分流在这些层已经决定了。 - - -### 启发 2 - -如果 plugin 只覆盖了: - -- `DeepseekV3ForCausalLM` - -但没有覆盖: - -- `DeepseekV3ForCausalLMNextN` - -那么最终一定会形成: - -- target 走 plugin -- draft 走 upstream - -的混合运行形态。 - - -### 启发 3 - -如果 attention backend 只按: - -- decode -- extend - -粗暴分流,而没有补: - -- `target_verify` -- `draft_extend` - -这类 speculative 专有 metadata 路径, -那么在 speculative 模式下一定迟早会在 verify / draft_extend 里出错。 - - -## 15. 最终总结 - -SGLang 的 speculative decoding 并不是一个局部 feature,而是一套完整的运行时体系: - -- 在配置层决定算法和 draft model 语义 -- 在 model config 层改写 draft 架构 -- 在 scheduler 层切换到 speculative orchestrator -- 在 worker 层维护 target / draft 协作 -- 在 batch 层用三层结构管理状态转换 -- 在 attention backend 层按 `ForwardMode` 和 `SpecInput` - 细化 metadata 初始化 - -如果后续要把 `ATOM MTP` 真正接到 plugin 路径上,最重要的不是先改单个 kernel, -而是先把这张图看清楚: - -- 谁负责 target -- 谁负责 draft -- 谁负责调度 -- 哪些数据结构在层层转换 -- speculative 特有的 `target_verify` / `draft_extend` - 在 attention/backend 层是如何被建模的 - -只有在这个架构认知稳定之后,后面的接入和调试才会高效。 diff --git a/work_log/MTP/2026-04-08-vllm-continuous-batching.md b/work_log/MTP/2026-04-08-vllm-continuous-batching.md deleted file mode 100644 index 2e643b89e..000000000 --- a/work_log/MTP/2026-04-08-vllm-continuous-batching.md +++ /dev/null @@ -1,1214 +0,0 @@ -# 2026-04-08 vLLM Continuous Batching 原理与源码位置笔记 - -## 文档目的 - -这份文档用于系统梳理 `vLLM` 中与 `continuous batching` 相关的核心机制, -重点回答下面几个问题: - -- `continuous batching` 到底是什么,和静态 batch 有什么本质差异 -- `scheduler` 是如何把多个 request 组装成一次 step 的执行 batch 的 -- 一次 step 之后,模型输出是如何再放回各个 request 的 -- 为什么 vLLM 的 batch 不是传统训练里那种规则的 `B x L` -- `prefill / decode / chunked prefill / speculative decode` 在这个框架下是怎样统一的 -- 如果要顺着源码看,应该优先看哪些代码位置 - -本文尽量从三个维度同时展开: - -- 系统设计 -- 张量 shape -- 源码入口 - -方便后续复盘、调试和与其他推理框架做对比。 - - -## 一句话总结 - -`vLLM continuous batching` 的核心不是“把很多请求 pad 成一个固定 `B x L` 大矩阵”, -而是: - -- 每个 step 动态决定“每个 request 这一步前进多少 token” -- 把这些 token 展平成一个按 token 计数的 flat batch -- 用 `query_start_loc / seq_lens / block_tables / slot_mapping` 等元数据告诉 GPU: - - 每个 token 属于哪个 request - - 它在该 request 中的位置是多少 - - 它应该读写哪一段 KV cache -- step 结束后,再通过 `req_id_to_index` 把输出准确拆回每个 request - -可以把它概括成: - -```text -requests - -> scheduler 决定本步每个 request 的 n_i - -> 组装成 flat token batch, 总 token 数 T = sum(n_i) - -> GPU forward / sample - -> 用 req_id_to_index 把输出拆回 request -``` - - -## 版本说明 - -本文主要覆盖两条线: - -- `v1 / current main` 风格实现 -- `v0` 旧版 `LLMEngine` / `SequenceGroup` 风格实现 - -需要注意: - -- `v1` 是当前更值得优先看的主线 -- `v0` 仍然有很强的参考价值,因为很多文章、issue、历史讨论都仍然沿用 - `SequenceGroup`、`SchedulerOutputs` 这套命名 -- 本文提到的源码位置与行号,基于 `2026-04-08` 抓取的 upstream 快照, - 后续可能会轻微漂移 - - -## 1. 为什么 continuous batching 不是静态 batching - -### 1.1 静态 batching 的直觉 - -训练或普通离线推理中,大家更熟悉的是: - -- 给定一批序列 -- pad 到同一个长度 -- 形成一个规则张量 - -例如: - -```text -input_ids.shape = [B, L] -attention_mask.shape = [B, L] -``` - -这种做法的假设是: - -- 这一批样本一起开始 -- 一起执行 -- 一起结束 - - -### 1.2 在线 serving 的问题 - -在线服务时,请求并不是同时到达,也不会同时结束。 - -典型情况是: - -- 某些 request 还在做长 prompt 的 prefill -- 某些 request 已经进入 decode,每步只需要前进 1 个 token -- 某些 request 刚结束 -- 某些新 request 又刚进入系统 - -如果还坚持用静态 batch,就会遇到: - -- 等待新 request 凑满 batch,TTFT 变差 -- 某个 request 提前结束后,batch 中留下空洞 -- prompt 很长的 request 会拖慢所有其他 request - - -### 1.3 continuous batching 的本质 - -所以 vLLM 的选择不是“固定一批 request 一起跑到结束”,而是: - -- 每个调度 step 都重新看当前系统中的 request -- 决定本步哪些 request 参与 -- 决定每个 request 本步前进多少 token -- step 结束后,再立刻重组下一轮 batch - -因此,batch 是“连续流动”的。 - -这也是 `continuous batching` 这个名字的真正含义。 - - -## 2. vLLM 视角下一个 request 的核心状态 - -在 vLLM 中,理解 request 的关键不是先区分“prefill 还是 decode”, -而是先看下面几个状态量。 - -### 2.1 最重要的两个量 - -- `all_token_ids` - - 当前这个 request 已知的完整 token 序列 - - 包括 prompt token,也包括已经生成出来但可能尚未被下一轮 compute 的 token -- `num_computed_tokens` - - `all_token_ids` 中已经真正做过 forward、对应 KV 已经落到 cache 的前缀长度 - -于是: - -- 如果 `num_computed_tokens = 0`,说明 prompt 还没 prefill -- 如果 `num_computed_tokens < len(all_token_ids)`,说明还有 backlog 没算 -- decode 阶段常见情况是: - - 上一轮 sample 出了 1 个新 token - - 下一轮需要把这 1 个 token 真正送进模型计算 - - 所以通常 backlog 是 1 - - -### 2.2 统一 prefill / decode 的关键观察 - -从 scheduler 角度,并不存在一个特别刚性的: - -- “prefill phase” -- “decode phase” - -更接近的真实逻辑是: - -- 对每个 request,看它还有多少 token 没被 compute -- 本轮决定从这些 backlog 中取多少 token 来执行 - -因此: - -- 新 request 的 backlog 通常很大,对应 prefill -- 老 request 的 backlog 常常只有 1,对应 decode -- chunked prefill 只是“长 request 的 backlog 一次不要全吃完” - -也就是说,`prefill / decode` 更像是同一调度框架下的两种常见形态。 - - -### 2.3 KV cache 也是 request 状态的一部分 - -除了 token 序列本身,每个 request 还绑定: - -- KV cache block -- block table -- 对应的 slot mapping - -这决定了: - -- decode 时虽然本轮可能只新输入 1 个 token -- 但模型仍然能通过 KV cache 读取全部历史上下文 - -所以一个 request 的有效状态并不只是 token ids,而是: - -```text -request state - = token sequence - + num_computed_tokens - + sampling / stop state - + KV cache mapping - + (可选)LoRA / multimodal / structured output 状态 -``` - - -## 3. 一个 batch 的“实质性内容”到底是什么 - -这是最容易误解的地方。 - -### 3.1 从 tokenizer 语义看 - -一个 token 最原始确实就是一个 vocab id,也就是一个整数。 - -例如: - -```text -"hello" -> 15496 -``` - -因此在输入层面,`input_ids` 的每个元素确实就是“一个数字”。 - - -### 3.2 从 GPU 执行看 - -但真正送进 GPU 跑一次 step,远远不只有 `input_ids`。 - -至少还需要: - -- `input_ids` -- `positions` -- `query_start_loc` -- `seq_lens` -- `block_tables` -- `slot_mapping` -- `logits_indices` - -特殊情况下还会有: - -- `inputs_embeds` -- multimodal encoder 相关输入 -- LoRA metadata -- speculative decode 的 draft token 相关索引 -- structured output grammar bitmask - -所以如果问: - -> 一个 batch 的实质性 token,是不是仅仅是简单的 input_id? - -答案是: - -- 对“token 身份”来说,最原始确实是 `input_id` -- 对“一次 forward 的完整执行语义”来说,绝对不够 - -因为模型还必须知道: - -- 这个 token 属于哪个 request -- 它在该 request 里的绝对位置是多少 -- 它应该从 KV cache 的哪一段读取历史上下文 - - -### 3.3 这个 `input_id` 在 GPU 上吗 - -在真正执行时,是的。 - -更准确地说: - -- request 和 scheduler 主要在 CPU 侧维护高层状态 -- 但进入 worker / model runner 后,本步需要用到的 - `input_ids / positions / query_start_loc / block_tables / slot_mapping` - 会被放入 GPU buffer -- 然后做 embedding lookup,进入 transformer forward - -所以: - -- 逻辑上的 token id 一开始常出现在 CPU 侧 -- 本步执行用到的 `input_ids` 会进入 GPU -- 进入模型后,它很快会被 embedding 成一个向量 - -例如: - -```text -token_id: scalar - -> embedding lookup - -> hidden vector: [hidden_size] -``` - -如果本轮总共执行 `T` 个 token,那么 embedding 后大致就是: - -```text -[T, hidden_size] -``` - - -## 4. scheduler 到底在做什么 - -### 4.1 核心目标 - -scheduler 的工作不是“把所有 request pad 成一个矩阵”,而是: - -- 从 `waiting / running` 队列里挑 request -- 决定每个 request 本步前进多少 token -- 保证不超出资源预算 -- 必要时做 preemption -- 产出 worker 能执行的调度结果 - - -### 4.2 主要约束 - -在 `v1` 中,最重要的两个约束是: - -- `max_num_seqs` - - 本步最多同时挂多少个 request -- `max_num_batched_tokens` 或 `max_num_scheduled_tokens` - - 本步总共最多前进多少 token - -此外还会考虑: - -- model max length -- encoder 计算预算(多模态) -- LoRA 同批数量限制 -- KV cache block 是否足够 -- prefix cache / remote KV / async loading 状态 - - -### 4.3 调度的高层流程 - -`v1` 的 `Scheduler.schedule()` 大致可以概括成: - -1. 先尝试调度 `running` request -2. 再尝试从 `waiting` 里吸入新 request -3. 对每个 request 决定本步的 `n_i` -4. 维护 `token_budget` -5. 如果 block 不够或约束冲突,必要时 preempt 某些 request -6. 输出 `SchedulerOutput` - -一个很关键的设计点是: - -> scheduler 关心的是 “本步每个 request 前进多少 token” - -而不是: - -> “这个 request 属于 prefill 还是 decode 类别” - - -### 4.4 chunked prefill 是怎样融进去的 - -长 prompt 的 request,如果一次全吃完会把 token budget 吃光, -拖累其他 request。 - -所以 vLLM 会在需要时把它拆成多步: - -- 本轮只 prefill prompt 的一部分 -- 剩下的下轮再继续 - -因此一个 request 可以出现: - -- 还在 prefill chunk 中 -- 但同时其他 request 已经在 decode - -这正是 continuous batching 最典型的混合场景。 - - -## 5. scheduler 输出的关键数据结构 - -在 `v1` 中,scheduler 输出的核心抽象是: - -- `NewRequestData` -- `CachedRequestData` -- `SchedulerOutput` - -可以把这三者理解成: - -- `NewRequestData` - - 首次进入 worker 的 request,要发送完整初始化数据 -- `CachedRequestData` - - worker 已经缓存过的 request,只发送增量信息 -- `SchedulerOutput` - - 这一步所有调度决策的总封装 - - -### 5.1 `NewRequestData` - -它通常包含: - -- `req_id` -- `prompt_token_ids` -- `sampling_params` -- `pooling_params` -- `block_ids` -- `num_computed_tokens` -- `lora_request` -- `prefill_token_ids`(v2 model runner 相关) - -也就是说,新 request 第一次进入 worker 时,需要把足够多的静态信息发过去, -让 worker 端建立自己的 request cache。 - - -### 5.2 `CachedRequestData` - -这个结构是 continuous batching 很重要的一环,因为它体现了: - -> worker 对 request 状态是“长期缓存”的,而不是每 step 重建。 - -典型字段有: - -- `req_ids` -- `resumed_req_ids` -- `new_token_ids` -- `all_token_ids` -- `new_block_ids` -- `num_computed_tokens` -- `num_output_tokens` - -其中最关键的思想是: - -- 对已经在 worker 里的 request,不重复发送整条 request -- 只发送变化的部分 - -这能显著减少调度端和 worker 之间的通信成本。 - - -### 5.3 `SchedulerOutput` - -最重要的字段有: - -- `scheduled_new_reqs` -- `scheduled_cached_reqs` -- `num_scheduled_tokens: dict[req_id, int]` -- `total_num_scheduled_tokens` -- `scheduled_spec_decode_tokens` -- `scheduled_encoder_inputs` -- `finished_req_ids` - -其中: - -- `num_scheduled_tokens` 是整轮 step 的核心 -- 它表达的是: - - 这个 request 这一步要前进几个 token - -如果把本轮调度到了 `B` 个 request,则: - -```text -num_scheduled_tokens: {req_id_1: n_1, ..., req_id_B: n_B} -T = n_1 + ... + n_B -``` - -这里: - -- `B` 是 request 数 -- `T` 是 token 数 - -vLLM 后续执行更偏向围绕 `T` 展开,而不是围绕规则的 `B x L`。 - - -## 6. 真正执行时,batch 的 shape 长什么样 - -这是理解 vLLM 最关键的一节。 - -### 6.1 不是 `[B, L]`,而是 token-flat `[T]` - -假设本轮有 `B` 个 request,第 `i` 个 request 本轮前进 `n_i` 个 token。 - -则: - -```text -T = sum_i n_i -``` - -执行时,最核心的输入通常是: - -- `input_ids`: `[T]` -- `positions`: `[T]` -- `query_start_loc`: `[B + 1]` -- `seq_lens`: `[B]` - -为了 CUDA graph 或执行约束,vLLM 里还常会有 padding 后版本: - -- `T_pad` -- `B_pad` - -于是实际 buffer 常是: - -- `input_ids`: `[T_pad]` -- `positions`: `[T_pad]` -- `seq_lens`: `[B_pad]` - - -### 6.2 `query_start_loc` 是什么 - -`query_start_loc` 是每个 request 在扁平 token buffer 中的边界。 - -如果: - -```text -n = [1, 1, 2, 6] -``` - -则: - -```text -query_start_loc = [0, 1, 2, 4, 10] -``` - -含义是: - -- 第 0 个 request 用 `input_ids[0:1]` -- 第 1 个 request 用 `input_ids[1:2]` -- 第 2 个 request 用 `input_ids[2:4]` -- 第 3 个 request 用 `input_ids[4:10]` - -这就是: - -- 一个大 flat token buffer -- 加一个分段索引数组 - -共同表达 ragged batch 的典型做法。 - - -### 6.3 decode 为何也能放进这个框架 - -decode request 在本轮通常只前进 1 个 token,所以常见: - -```text -n_i = 1 -``` - -那它在扁平 batch 里也就只占一个元素。 - -例如: - -```text -input_ids = [r1_new, r2_new, p0, p1, p2, p3] -``` - -这里前两个是 decode token,后四个是某个 prefill request 的 prompt chunk。 - -看起来 decode token 很“短”,但它并不缺上下文,因为上下文来自: - -- `seq_lens` -- `block_tables` -- `slot_mapping` -- KV cache - - -### 6.4 还有哪些重要 shape - -除了上面几个,attention 执行时还非常依赖: - -- `block_tables` - - 近似可以看成:每个 KV cache group 一份 - - 形状常见近似为 `[B_pad, max_num_blocks]` -- `slot_mapping` - - 近似为每个 token 映射到哪个 KV slot - - 常见近似为 `[T_pad]` - -因此,从 GPU 视角看,一个 batch 更接近: - -```text -flat token payload - + per-request segmentation metadata - + KV cache address metadata -``` - -而不是简单的 `input_ids` 矩阵。 - - -## 7. 一次 step 的完整生命周期 - -在 `v1` 中,可以把一次 step 概括为: - -1. `schedule()` -2. `execute_model(...)` -3. `update_from_output(...)` -4. `OutputProcessor.process_outputs(...)` - -下面按顺序拆开。 - - -### 7.1 `schedule()` - -scheduler 产生: - -- 哪些 request 参与本轮 -- 每个 request 本轮前进多少 token -- 新 request / cached request 的增量更新数据 - -同时,vLLM 还有一个很值得注意的设计: - -- request 被 schedule 到后,会先把 `num_computed_tokens` 往前推进 -- 这样它可以在下一轮继续被及时调度 -- 如果后面 speculative token 有拒绝,再在 `update_from_output()` 里回调修正 - -这说明: - -- 调度状态和最终 sample 结果之间不是完全同步的 -- 某些统计量会“先乐观推进,再按输出修正” - - -### 7.2 `execute_model(...)` - -worker 侧收到 `SchedulerOutput` 后,会做: - -- `add_requests()` - - 初始化首次进入 worker 的 request -- `update_requests()` - - 更新已有 request 的 block / token 等状态 -- `prepare_inputs()` - - 组装本轮 flat token batch -- `prepare_attn()` - - 生成 attention metadata -- 执行模型 forward -- sample token / 或做 pooling - -执行完成后返回 `ModelRunnerOutput`。 - - -### 7.3 `ModelRunnerOutput` - -这是“从 GPU 结果回到 scheduler”的关键桥梁。 - -核心字段可以理解成: - -- `req_ids`: `[B]` -- `req_id_to_index: {req_id -> batch_idx}` -- `sampled_token_ids`: `list[list[int]]` -- `logprobs` -- `prompt_logprobs_dict` -- `pooler_output` - -其中最关键的是: - -- `req_id_to_index` -- `sampled_token_ids` - -因为 worker 为了执行效率可能重排 request 顺序,所以 scheduler 回填时不能假设: - -- “第 0 个输出一定属于第 0 个 request” - -而必须显式做: - -```text -idx = req_id_to_index[req_id] -generated = sampled_token_ids[idx] -``` - - -### 7.4 `update_from_output(...)` - -这一步负责: - -- 根据 `req_id_to_index` 找到每个 request 对应的输出 -- 把 `sampled_token_ids[idx]` 回填到 request 状态 -- 检查 stop / eos / length -- 处理 speculative decode 的接受 / 拒绝 -- 必要时释放 request 的 KV cache -- 产出 `EngineCoreOutput` - -这里有一个非常重要的概念区分: - -- `n_i` - - 本轮这个 request 被安排去“计算”的 token 数 -- `g_i` - - 本轮真正“生成出来并回给请求”的 token 数 - -这两个量不一定相等。 - -典型例子: - -- chunked prefill 时,`n_i > 0`,但 `g_i = 0` -- 普通 decode 时,通常 `n_i = 1`,`g_i = 1` -- speculative decode 时,可能 `n_i = 1 + k`,而 `g_i` 可以大于 1 - - -### 7.5 `OutputProcessor.process_outputs(...)` - -这一步负责从 engine 内部输出变成用户能看到的 `RequestOutput`。 - -主要工作有: - -- detokenize -- stop string 检查 -- logprobs 处理 -- 组装 `RequestOutput` - -因此完整链路是: - -```text -SchedulerOutput - -> ModelRunnerOutput - -> EngineCoreOutput - -> RequestOutput -``` - - -## 8. 例子一:3 个 request,4 个 step - -下面给一个不带 speculative decode 的完整 toy example。 - -假设配置: - -- `max_num_seqs = 3` -- `max_num_batched_tokens = 6` -- 开启 `chunked prefill` - -4 个请求依次到达: - -- `R1 = [11,12,13,14,15]` -- `R2 = [21,22]` -- `R3 = [31,32,33,34]` -- `R4 = [41,42,43]` - -初始状态: - -```text -R1: all=[11,12,13,14,15], comp=0 -R2: all=[21,22], comp=0 -``` - - -### Step 0 - -scheduler 选择: - -- `R1` 前进 4 个 prompt token -- `R2` 前进 2 个 prompt token - -于是: - -```text -num_scheduled_tokens = {R1: 4, R2: 2} -B = 2 -T = 6 -``` - -worker 侧可能重排为: - -```text -req_ids = [R2, R1] # shape [2] -num_scheduled_tokens = [2, 4] # shape [2] -query_start_loc = [0, 2, 6] # shape [3] - -input_ids = [21,22, 11,12,13,14] # shape [6] -positions = [0,1, 0,1,2,3] # shape [6] -seq_lens = [2,4] # shape [2] -``` - -假设这一轮输出: - -- `R2` prompt 已结束,sample 到首个生成 token `23` -- `R1` 还没结束 prefill,没有生成 token - -则: - -```text -req_id_to_index = {R2: 0, R1: 1} -sampled_token_ids = [[23], []] -``` - -回填后: - -```text -R1: all=[11,12,13,14,15], comp=4 -R2: all=[21,22,23], comp=2 -``` - -注意: - -- `R2` 的 `23` 已追加到 `all_token_ids` -- 但 `comp=2`,因为本轮真正算进 KV 的还是原 prompt 的 2 个 token -- `23` 会在下一轮真正参与 decode compute - - -### Step 1 - -此时 `R3` 到达。 - -现在系统中: - -- `R1` 还差 1 个 prompt token -- `R2` 要 decode 它的 `23` -- `R3` 是新 request,需要 prefill - -于是 scheduler 可以在同一步混合调度: - -```text -num_scheduled_tokens = {R1: 1, R2: 1, R3: 4} -B = 3 -T = 6 -``` - -执行 batch: - -```text -req_ids = [R1, R2, R3] # shape [3] -query_start_loc = [0, 1, 2, 6] # shape [4] - -input_ids = [15, 23, 31,32,33,34] # shape [6] -positions = [4, 2, 0,1,2,3] # shape [6] -seq_lens = [5, 3, 4] # shape [3] -``` - -假设输出: - -```text -sampled_token_ids = [[101], [24], [35]] -``` - -回填后: - -```text -R1: all=[11,12,13,14,15,101], comp=5 -R2: all=[21,22,23,24], comp=3 -R3: all=[31,32,33,34,35], comp=4 -``` - -这一步非常重要,因为它体现了: - -- `R1` 还在补最后一段 prefill -- `R2` 已在 decode -- `R3` 是新 request 的整段 prefill - -三者可以在同一个 step 里并存。 - - -### Step 2 - -现在三者都进入正常 decode: - -```text -num_scheduled_tokens = {R1:1, R2:1, R3:1} -B = 3 -T = 3 - -input_ids = [101, 24, 35] # shape [3] -query_start_loc = [0, 1, 2, 3] # shape [4] -``` - -假设输出: - -```text -sampled_token_ids = [[102], [2], [36]] -``` - -若 `2` 是 EOS,则: - -- `R2` 完成 -- `R2` 的 KV block 可被释放 -- `R1`、`R3` 继续保留 - - -### Step 3 - -此时 `R4` 到达,于是可以立刻填补空位: - -```text -num_scheduled_tokens = {R1:1, R3:1, R4:3} -B = 3 -T = 5 - -input_ids = [102, 36, 41,42,43] # shape [5] -query_start_loc = [0, 1, 2, 5] # shape [4] -``` - -这就是 continuous batching 的最直观体现: - -- 老 request 结束后立即移出 -- 新 request 马上补进来 -- 系统不是“整批结束再换下一批”,而是每一步都在流动 - - -## 9. 例子二:为什么 `n_i` 不等于 `g_i` - -很多人第一次看时会默认: - -- scheduler 安排这个 request 算 4 个 token -- 那它就应该返回 4 个 token - -这在 vLLM 中并不成立。 - -### 9.1 chunked prefill 场景 - -假设一个长 prompt request: - -```text -prompt = [p0, p1, p2, p3, p4, p5, p6, p7] -num_computed_tokens = 0 -``` - -本轮只给它 4 个 token budget: - -```text -n_i = 4 -input_ids = [p0, p1, p2, p3] -``` - -如果 prompt 还没 prefill 完,则这轮: - -```text -g_i = 0 -``` - -也就是: - -- 本轮确实算了 4 个 token -- 但没有新生成 token 回给用户 - - -### 9.2 普通 decode 场景 - -如果一个 request 已经完成 prefill,只差 decode: - -```text -n_i = 1 -g_i = 1 -``` - -这是最常见、也最容易理解的情况。 - - -### 9.3 speculative decode 场景 - -如果开启 speculative decode,情况会变成: - -- 本轮可能先有若干 draft token -- target verify 后可能一次接受多个 token - -于是可能出现: - -```text -n_i = 1 + k -g_i = m -``` - -其中: - -- `k` 是 draft 相关的额外计算 -- `m` 是最终接受并回填的 token 数 -- `m` 可以大于 1 - - -## 10. 例子三:speculative decode 的 shape 直觉 - -假设某个 request 本轮有 3 个 draft token: - -```text -scheduled_spec_decode_tokens = { - R1: [501, 502, 503] -} -``` - -worker 执行后,假设 target verify: - -- 接受了前 2 个 draft token -- 然后再给出 1 个新的 target token - -则返回到 scheduler 的 `generated_token_ids` 可能近似为: - -```text -sampled_token_ids[idx_of_R1] = [501, 502, 900] -``` - -此时: - -- `num_draft_tokens = 3` -- `num_accepted = len(generated_token_ids) - 1 = 2` -- `num_rejected = 3 - 2 = 1` - -也就是说: - -- 被接受的 draft token 会直接作为 output 回填 -- 被拒绝的 draft token 要把之前乐观推进的 `num_computed_tokens` - 再修正回来 - -这也是为什么 scheduler 与 output update 之间会有一个“先推进、后修正”的配合。 - - -## 11. v0 和 v1 的关系 - -如果你看旧版文章,经常会看到: - -- `Sequence` -- `SequenceGroup` -- `ScheduledSequenceGroup` -- `RequestOutput.from_seq_group(...)` - -这主要是 `v0` 风格。 - -### 11.1 v0 的回填方式 - -旧版 `LLMEngine.step()` 的高层流程大致是: - -```text -scheduler.schedule() - -> model_executor.execute_model(...) - -> _process_model_outputs(...) - -> RequestOutput.from_seq_group(...) -``` - -这里的回填更偏向: - -- 先把 sampler output 按 sequence group 拆好 -- 再依次更新每个 `SequenceGroup` - - -### 11.2 v1 的回填方式 - -`v1` 更偏向 request-centric: - -- scheduler 输出 `num_scheduled_tokens` -- worker 返回 `req_id_to_index` -- scheduler 用 `req_id_to_index` 查表回填 - -所以: - -- `v0` 更像“按 sequence group 顺序回填” -- `v1` 更像“按 req_id 显式映射回填” - -但本质是一样的: - -> 执行 batch 在 GPU 侧可以重排、压平、做优化; -> 但 step 结束后必须有一套稳定映射,把输出放回正确的 request。 - - -## 12. 推荐源码入口 - -下面给出一份更适合顺着看的源码索引。 - -### 12.1 v1 主线 - -#### 1. 调度输出结构 - -- `vllm/v1/core/sched/output.py` - - `NewRequestData` - - `CachedRequestData` - - `SchedulerOutput` - -建议先看它,因为它定义了 scheduler 究竟在给 worker 发送什么。 - - -#### 2. scheduler 主入口 - -- `vllm/v1/core/sched/scheduler.py` - - `Scheduler.schedule()`,约 `348` - - `_make_cached_request_data()`,约 `1055` - - `update_from_output()`,约 `1302` - -这是最核心的一组函数。 - -尤其推荐按下面顺序看: - -1. `schedule()` -2. `_make_cached_request_data()` -3. `update_from_output()` - - -#### 3. engine step - -- `vllm/v1/engine/core.py` - - `EngineCore.step()`,约 `380` - -这能把高层链路串起来: - -```text -schedule - -> execute_model - -> update_from_output -``` - - -#### 4. worker 侧 batch 组装 - -- `vllm/v1/worker/gpu/model_runner.py` - - `add_requests()`,约 `612` - - `update_requests()`,约 `657` - - `prepare_inputs()`,约 `667` - -如果你最关心 shape,`prepare_inputs()` 是必须看的。 - -它直接体现: - -- `num_scheduled_tokens -> query_start_loc` -- `flat input_ids / positions` -- `seq_lens` -- `cu_num_logits` -- speculative decode 相关展开 - - -#### 5. 输出结构 - -- `vllm/v1/outputs.py` - - `SamplerOutput` - - `ModelRunnerOutput` - -这决定了从 GPU 回 scheduler 时到底带了哪些数据。 - - -#### 6. engine 内部输出与最终用户输出 - -- `vllm/v1/engine/__init__.py` - - `EngineCoreOutput` - - `EngineCoreOutputs` -- `vllm/v1/engine/output_processor.py` - - `RequestState.make_request_output()`,约 `269` - - `OutputProcessor.process_outputs()`,约 `572` - -这部分更偏“回填后的用户接口层”。 - - -### 12.2 旧版 v0 参考线 - -- `vllm/engine/llm_engine.py` - - `_process_model_outputs(...)`,约 `510` - - `step()`,约 `557` - -适合在下面两种情况下参考: - -- 你看到旧文档 / issue 还在讲 `SequenceGroup` -- 你想对照理解 vLLM 是怎样从旧结构演化到 `v1` 的 - - -## 13. 看源码时建议抓住的 5 个问题 - -如果你在调试 continuous batching,建议始终围绕下面几个问题读代码。 - -### 13.1 本轮到底调度了哪些 request - -看: - -- `num_scheduled_tokens` -- `scheduled_new_reqs` -- `scheduled_cached_reqs` - - -### 13.2 每个 request 本轮前进了多少 token - -看: - -- `n_i = num_scheduled_tokens[req_id]` - - -### 13.3 扁平 batch 的边界在哪里 - -看: - -- `query_start_loc` -- `seq_lens` - - -### 13.4 输出怎么知道属于哪个 request - -看: - -- `req_id_to_index` -- `sampled_token_ids[idx]` - - -### 13.5 request 状态什么时候推进,什么时候修正 - -看: - -- schedule 后 `num_computed_tokens` 的推进 -- speculative decode 拒绝后在 `update_from_output()` 中的修正 - - -## 14. 常见误区 - -### 14.1 “一个 batch 就是一个 `input_ids` 矩阵” - -不对。 - -vLLM 更接近: - -```text -input_ids[T] - + positions[T] - + query_start_loc[B+1] - + seq_lens[B] - + block_tables - + slot_mapping -``` - - -### 14.2 “decode 只输入 1 个 token,所以计算很简单” - -不对。 - -decode 本轮虽然只新输入 1 个 token id,但它会读取整条历史序列对应的 KV cache, -真正的上下文并没有消失。 - - -### 14.3 “本轮调度了几个 token,就一定返回几个 token” - -不对。 - -要始终区分: - -- 计算的 token 数 `n_i` -- 生成并回填的 token 数 `g_i` - - -### 14.4 “prefill 和 decode 是两套完全不同的调度器” - -不对。 - -在 vLLM 的设计里,它们更像是同一个“按 backlog 前进”的调度框架下的不同常见情形。 - - -## 15. 最终总结 - -从实现上看,`vLLM continuous batching` 的关键可以归纳成下面几句话: - -- scheduler 的核心决策单位不是“固定长度序列”,而是“本步每个 request 前进多少 token” -- worker 的执行核心不是规则的 `B x L`,而是 flat token batch `[T]` -- `query_start_loc / seq_lens / block_tables / slot_mapping` 决定了这些 token 如何映射回各自 request 与 KV cache -- step 结束后,`req_id_to_index` 负责把输出准确拆回 request -- `n_i` 与 `g_i` 不一定相等,这一点对理解 chunked prefill 与 speculative decode 非常重要 - -所以,continuous batching 真正连续流动的不是“一个静态矩阵”,而是: - -- request 集合在流动 -- 每步前进 token 数在流动 -- GPU token batch 的形状在流动 -- 完成与新加入的 request 在每个 step 都会重新重组 - -这正是它能在在线 serving 中同时兼顾: - -- 吞吐 -- 低延迟 -- 动态请求混合 - -的根本原因。 diff --git a/work_log/MTP/2026-04-10-sglang-cudagraph-prefill-decode-metadata-guide.md b/work_log/MTP/2026-04-10-sglang-cudagraph-prefill-decode-metadata-guide.md deleted file mode 100644 index 3250f331d..000000000 --- a/work_log/MTP/2026-04-10-sglang-cudagraph-prefill-decode-metadata-guide.md +++ /dev/null @@ -1,734 +0,0 @@ -# SGLang CUDAGraph、Prefill/Decode 与 Attention Metadata 说明 - -## 文档目的 - -这篇文档专门回答下面几个问题: - -- `CUDAGraph` 在 `SGLang` 里到底固定了什么 -- 为什么 `decode` 更适合做 `CUDAGraph` -- 为什么普通 `prefill / extend` 的 attention 很难复用同一张 graph -- `SGLang` 在 `decode` 阶段是怎样做 `capture / replay` 的 -- `attention backend` 的 `ForwardMetadata` 在 graph capture / replay 中扮演什么角色 -- 在 `ATOM plugin` 模式下,哪些 graph-only bug 容易出现,为什么 - - -## 一句话结论 - -最重要的结论先说: - -- `CUDAGraph` 固定的不是“某个 Python 函数调用”,而是一整段已经展开好的 CUDA 执行脚本 -- `decode` 更适合 graph,不是因为它没有 metadata,而是因为它的 **query 结构、token 数、kernel 形状、workspace 结构** 更稳定 -- 普通 `prefill / extend` 难 graph,不是因为 kernel 不能读不同的 metadata,而是因为 metadata 往往不只是“输入数据”,而是会影响: - - 走哪条代码路径 - - 中间 tensor 分配多大 - - gather 后张量有多长 - - workspace 形状是什么 - - 最终 launch 的 kernel 形态是什么 -- 换句话说: - - `decode` 下,metadata 更像 **kernel 参数** - - `prefill` 下,metadata 更像 **图结构控制器** - - -## 1. CUDAGraph 真正固定的是什么 - -很多人第一次接触 `CUDAGraph` 时,会误以为它只是“把一次 forward 缓存起来”。 - -更准确地说,`CUDAGraph` capture 固定的是: - -- 这次 forward 里实际 launch 了哪些 CUDA kernel -- kernel 的调用顺序 -- 每个 kernel 看到的 tensor shape / stride -- 这些 tensor 和 workspace 的内存地址 -- Python 层已经展开后的控制流分支 - -因此: - -- **tensor 的值**可以变 -- 但 **shape / 地址 / 分支 / launch 计划** 最好不要变 - -可以把它想成: - -- eager 模式像“每次现写一遍执行计划” -- cuda graph 像“录下这次执行计划,以后按原样回放” - - -## 2. 三个最容易混淆的量:`raw_bs`、`bs`、`num_tokens` - -理解 graph 之前,必须先区分三个量: - -- `raw_bs` - - 当前真实 batch 里有多少个 request -- `bs` - - 当前 replay 选中的 graph bucket 大小 -- `num_tokens` - - 这次真正传给很多 layer 的 token 数 - -它们经常不相等。 - -### 2.1 `raw_bs` - -这是 scheduler 当前真实调度出来的 request 数。 - -例如: - -- 真实只有 3 个 request 要做 decode -- 那么 `raw_bs = 3` - -### 2.2 `bs` - -这是 graph 系统为了复用固定 shape,选中的 capture bucket。 - -例如 capture 过这些 bucket: - -- `[1, 2, 4, 8, 16, 32, 48]` - -如果这次真实请求数是 3,系统可能会选择: - -- `bs = 4` - -然后: - -- 前 3 个位置放真实请求 -- 第 4 个位置放 padding / fill value - -### 2.3 `num_tokens` - -这不是永远等于 `bs`。 - -它取决于当前模式下“每个 request 本轮贡献多少 query token”。 - -几个典型场景: - -| 场景 | `num_tokens` | -|------|--------------| -| 普通 decode | `bs * 1` | -| target verify | `bs * num_draft_tokens` | -| draft decode | `bs * topk` | -| draft extend | `bs * (speculative_num_steps + 1)` | -| 普通 prefill / extend | 通常是 `sum(extend_seq_lens)`,不一定等于 `bs * 常数` | - -这也是为什么: - -- 很多 layer 看到的输入 shape 是 `[num_tokens, hidden_size]` -- 而 graph bucket 却还是按 `bs` 来管理 - - -## 3. SGLang 为什么默认把 graph 重点放在 decode - -`SGLang` 的通用 `CudaGraphRunner` 默认 capture 的 forward mode 是 `DECODE`: - -- 初始化时先设: - - `capture_forward_mode = ForwardMode.DECODE` - - `num_tokens_per_bs = 1` -- 若是 speculative target verify,再切成: - - `ForwardMode.TARGET_VERIFY` - - `num_tokens_per_bs = speculative_num_draft_tokens` -- 若是 `DLLM_EXTEND`,再切成 block-size 固定模式 - -关键点是: - -- 这些模式都满足“每个 request 贡献固定个数的 query token” - -而普通 `prefill / extend` 不满足这一点。 - -从 `sglang/python/sglang/srt/model_executor/cuda_graph_runner.py` 可以直接看到这件事: - -- graph runner 默认按 `DECODE` 组织 -- `num_tokens_per_bs` 是固定常数 -- 再用它去算: - - `max_bs` - - `max_num_token` - - 静态输入 buffer 的大小 - - -## 4. 为什么 decode 比 prefill 更适合 graph - -### 4.1 decode 的 query 结构稳定 - -decode 下,一个 request 往往只算一个 query token。 - -因此: - -- `max_q_len` 通常固定为 `1` -- `num_tokens = bs` -- `qo_indptr` 结构非常规则 -- kernel 形状更容易随 `bs bucket` 固定下来 - -即使 metadata 中像: - -- `kv_indptr` -- `kv_indices` -- `kv_last_page_len` - -这些值每轮都变,它们大多数时候也只是: - -- 作为固定 kernel 的输入索引参数 - -而不是决定“这次图长什么样”。 - -### 4.2 prefill / extend 的 query 结构是 ragged 的 - -prefill / extend 下,每个 request 这一轮要处理多少 query token,通常不一样。 - -例如: - -- request A 新增 3 个 token -- request B 新增 17 个 token -- request C 新增 1 个 token - -这时: - -- `num_tokens = 3 + 17 + 1` -- `qo_indptr` 随分布变化 -- `max_q_len` 随分布变化 -- `max_kv_len` 也随上下文长度变化 - -这不是简单的“值不同”,而是 batch 的 **几何结构** 不同。 - - -## 5. 为什么“metadata 改变”会阻碍 prefill graph 复用 - -这个问题最容易被误解。 - -### 5.1 先说清楚:metadata 变,不一定阻碍 capture - -如果你拿某一个固定的 prefill batch 去做 capture,这次 capture 可能是成功的。 - -因为那一刻: - -- `q.shape` -- `kv_indices.shape` -- `qo_indptr` -- `max_q_len` -- `max_kv_len` - -都是确定的。 - -所以问题不在“这次能不能录下来”,而在: - -- **下一次不同的 prefill batch 还能不能 replay 这张图** - -### 5.2 decode 中 metadata 更像“数据” - -decode 中,metadata 变化通常只是: - -- 不同 request 对应不同 KV 索引 -- 不同 request 当前上下文长度不同 - -但最终仍然是在执行同一类 decode kernel。 - -所以它们更像: - -- 同一张图里的输入数据 - -### 5.3 prefill 中 metadata 更像“图结构控制器” - -在普通 MLA extend / prefill 里,metadata 会直接影响: - -1. 走哪条 Python 分支 -2. 中间张量 shape -3. workspace 大小 -4. gather 结果长度 -5. kernel 的 `max_q_len / max_kv_len` - -这就是根本区别。 - - -## 6. 用 ATOM plugin 的 MLA extend 代码看这个问题 - -`ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` 里的 `MLA extend` 很能说明问题。 - -### 6.1 metadata 先决定本轮的 ragged 结构 - -普通 MLA extend 初始化时会根据当前 batch 更新: - -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `max_q_len` -- `max_kv_len` - -它们都不是常量,而是来自当前 batch 的 `extend_seq_lens / seq_lens`。 - -### 6.2 metadata 决定走哪条代码路径 - -在 `_forward_extend_mla_normal()` 里,代码会根据 prefix 情况和 cache 形态走不同分支: - -- 无 prefix -- 有 prefix 且要 decompress -- 有 prefix 且走 absorbed MLA - -这意味着: - -- 不同 batch 可能走完全不同的子函数 -- graph capture 录下来的并不是“抽象 extend”,而是“某一条具体 extend 分支” - -如果第一次 capture 时: - -- `extend_no_prefix = True` - -那录下来的是 `_extend_mla_no_prefix()` 这条图。 - -下一次如果: - -- `extend_no_prefix = False` - -且要走 `_extend_mla_absorbed_prefix()`,那已经不是同一张图。 - -### 6.3 metadata 决定中间 tensor 的 shape - -在 no-prefix prefill 路径里,会根据当前 query token 总数构造: - -- `temp_kv_indices` -- `output` - -它们的 shape 直接依赖: - -- `q.shape[0]` -- `total_s` - -而 `q.shape[0]` 本身就是当前 ragged batch 展平后的 token 数。 - -换一个 prefill batch: - -- `total_s` 变了 -- 中间 tensor shape 跟着变 - -那 graph 也就不再可复用。 - -### 6.4 metadata 决定 workspace 的 shape - -FP8 prefill 路径里,还会根据: - -- `reduce_partial_map.size(0)` -- `total_s` - -分配: - -- `logits` -- `attn_lse` -- `final_lse` -- `output` - -而 `reduce_partial_map` 正是从当前 batch 的分段结构推出来的。 - -所以这不是“kernel 读不同 metadata”这么简单,而是: - -- metadata 直接控制要分配多大的临时缓冲区 - -### 6.5 metadata 决定 gather 后张量长度 - -在 absorbed prefix 路径里,会先: - -- `k_selected = torch.index_select(K_Buffer, 0, kv_indices)` - -这里 `k_selected.shape[0]` 就等于: - -- `len(kv_indices)` - -而 `kv_indices` 的长度也是当前 batch 的结构量。 - -因此: - -- prefix KV gather 后的张量 shape 也会跟 batch 变化 - -这会继续向下游 kernel 传播。 - - -## 7. 为什么不能简单靠 padding 解决普通 prefill - -有人会自然想到: - -- 既然 decode 能靠 bucket + padding 做 graph -- 那 prefill 也可以 pad 到固定 `bs / max_q_len / max_kv_len` - -理论上不是完全不行,但工程上代价很大。 - -### 7.1 decode 的 padding 成本小 - -decode 一般每个 request 只处理一个 token。 - -所以即使: - -- `raw_bs = 3` -- `bs = 4` - -多 pad 一个 request 的成本也比较低。 - -### 7.2 prefill 的 padding 成本会放大 attention 计算 - -prefill attention 的成本接近: - -- query token 数 -- context 长度 -- ragged 结构 - -的组合增长。 - -如果为了 graph,把所有 request 都 pad 成: - -- 大 `max_q_len` -- 大 `max_kv_len` - -那么: - -- 无效 token 也要参与很多 attention 计算 -- mask / metadata 也会跟着变大 -- workspace 和显存开销也会膨胀 - -最后可能: - -- graph 省下来的 launch 开销 -- 远远抵不过 padding 带来的额外 attention FLOPs - - -## 8. 为什么 `TARGET_VERIFY` / `DRAFT_EXTEND` 又能 graph - -因为它们虽然也不是普通 decode,但仍然满足: - -- 每个 request 的 query token 数是固定常数 - -例如: - -- `TARGET_VERIFY` - - 每个 request 验证 `num_draft_tokens` 个 token -- `DRAFT_EXTEND` - - 每个 request 固定处理 `speculative_num_steps + 1` 个 token - -所以它们仍然可以用: - -- `bs bucket` -- `num_tokens_per_bs` - -来组织 graph。 - -换句话说: - -- 它们不是“完全自由的 ragged prefill” -- 而是“固定 token-per-request 的特殊 extend” - -因此 graph 化难度明显低于普通 prefill。 - - -## 9. SGLang 在 decode 阶段怎样做 CUDAGraph capture - -下面按实际代码链路讲。 - -### 9.1 第一步:决定 capture 模式和 bucket - -`CudaGraphRunner.__init__()` 中会: - -1. 设定 `capture_forward_mode` -2. 设定 `num_tokens_per_bs` -3. 通过 `get_batch_sizes_to_capture()` 得到 `capture_bs` -4. 算出: - - `max_bs` - - `max_num_token = max_bs * num_tokens_per_bs` - -这一步的意义是: - -- graph 系统先把“这类 forward 的形状规则”固定下来 -- 然后再一次性分配足够大的静态 buffer - -### 9.2 第二步:attention backend 先分配 graph 专用静态状态 - -接着会调用: - -- `attn_backend.init_cuda_graph_state(max_bs, max_num_token)` - -在 `ATOMAttnBackendForSgl` 里,这一步会分配 graph 期间复用的持久 buffer,例如: - -- `cuda_graph_kv_last_page_len` -- `cuda_graph_kv_indices` -- `page_table` -- `seq_lens` -- MLA decode 的 `work_metadata / work_indptr / work_info_set / reduce_*` - -这里的关键思想是: - -- graph replay 期间,不再频繁新建这些结构 -- 而是在固定 buffer 上反复更新其内容 - -### 9.3 第三步:为某个具体 bucket 构造静态输入视图 - -在 `capture_one_batch_size(bs)` 中,会从大 buffer 上切出本 bucket 对应的视图,例如: - -- `input_ids = buffers.input_ids[:num_tokens]` -- `req_pool_indices = buffers.req_pool_indices[:bs]` -- `seq_lens = buffers.seq_lens[:bs]` -- `positions = buffers.positions[:num_tokens]` - -然后构造一个 `ForwardBatch`: - -- `forward_mode = capture_forward_mode` -- `batch_size = bs` -- 大部分字段都直接绑定到这些静态 buffer 视图上 - -### 9.4 第四步:capture 前先初始化 attention metadata - -在真正 `graph capture` 之前,先调用: - -- `attn_backend.init_forward_metadata_capture_cuda_graph(...)` - -对 decode 来说,这一步本质上是: - -- 根据当前 `req_pool_indices / seq_lens` -- 把 `ForwardMetadata` 组装到 graph 专用的静态 buffer 视图上 - -对于 MLA decode,它会构造: - -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `kv_last_page_len` -- `work_metadata / work_indptr / work_info_set / reduce_*` - -而这些对象大多来自: - -- graph state 中预先分配好的持久 buffer - -### 9.5 第五步:跑几次 warmup,再进入真正 graph capture - -`capture_one_batch_size()` 里会: - -1. 先调用几次 `run_once()` -2. 再进入 `torch.cuda.graph(...)` -3. 把这次 forward 录下来 - -在这个过程中: - -- 输入 buffer 地址固定 -- metadata buffer 地址固定 -- forward mode 固定 -- kernel 形状固定 - -于是得到一张与 bucket `bs` 绑定的 graph。 - - -## 10. SGLang 在 decode 阶段怎样做 replay - -### 10.1 先从真实 batch 选一个 bucket - -在 `replay_prepare()` 中: - -1. 读取真实 batch 的: - - `raw_bs` - - `raw_num_token` -2. 从 `capture_bs` 中找一个: - - `bs >= raw_bs` - -这一步就是把真实 batch 映射到 graph bucket。 - -### 10.2 把真实数据 copy 到静态 buffer 的前缀 - -调用: - -- `buffers.populate_from_forward_batch(...)` - -会把真实 batch 的内容写入静态 buffer 的前缀区域,例如: - -- `input_ids[:raw_num_token]` -- `req_pool_indices[:raw_bs]` -- `seq_lens[:raw_bs]` -- `positions[:raw_num_token]` - -如果 `bs != raw_bs`,还会: - -- 用 fill value / zero 对后面的 padding 段做补齐 - -### 10.3 replay 前重建本轮 metadata - -随后调用: - -- `attn_backend.init_forward_metadata_replay_cuda_graph(...)` - -注意这一步非常关键: - -- graph replay 不是复用 capture 当时的 metadata 值 -- 而是复用 **metadata 的静态 buffer 与构造方式** -- 然后把本轮真实 batch 的索引内容重新写进去 - -也就是说: - -- 地址固定 -- 内容可变 - -对 decode 来说,这正是 graph 友好的做法。 - -### 10.4 最后 `graph.replay()` - -当静态 buffer 和 metadata 都准备好后,就直接: - -- `self.graphs[graph_key].replay()` - -执行那张已经 capture 好的图。 - -输出拿到后,再按照: - -- `raw_bs` -- `raw_num_token` - -把 padding 的尾部裁掉。 - - -## 11. Attention Metadata 在 graph capture / replay 中的角色 - -可以把 `ForwardMetadata` 在 graph 中的角色概括成一句话: - -- 它不是 graph 外的额外说明书 -- 它是 graph 里 attention kernel 的直接输入 - -但它在不同模式下的“地位”不同。 - -### 11.1 decode 中:metadata 更像固定地址上的输入参数 - -decode graph 下,像: - -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `kv_last_page_len` - -更多是在表达: - -- 当前 batch 的 KV 可见范围 -- 当前 batch 的 query 分段 - -这些值会变,但: - -- 它们所在的 buffer 地址固定 -- 它们的 shape 规则受 bucket 控制 -- 下游还是同一类 decode kernel - -所以 graph 可以复用。 - -### 11.2 prefill 中:metadata 会升级成“图结构的一部分” - -prefill / extend 下,metadata 往往不仅仅被 kernel 读取,还会影响: - -- 选择哪条路径 -- 构造哪些中间 tensor -- 中间 tensor 有多大 -- workspace 有多大 -- kernel 看到的 `max_q_len / max_kv_len` - -因此: - -- metadata 变化会把“图长什么样”一起改掉 - -这就是它阻碍 graph 复用的根本原因。 - - -## 12. `ATOMAttnBackendForSgl` 中 graph metadata 的几个关键点 - -### 12.1 `init_cuda_graph_state()`:先分配 graph 专用持久 buffer - -plugin backend 里专门分配了: - -- `cuda_graph_kv_last_page_len` -- `cuda_graph_kv_indices` -- `page_table` -- `seq_lens` -- MLA decode 的 persistent workspace - -这样 replay 时就能复用这些地址。 - -### 12.2 `init_forward_metadata_capture_cuda_graph()`:把 bucket 数据写成 metadata - -这一步会根据当前 mode 做不同初始化: - -- `decode_or_idle` -- `target_verify` -- `draft_extend` - -每种模式都把: - -- bucket 对应的 `bs` -- 固定的 `num_tokens_per_bs` -- 当前 request 索引和 seq_lens - -转成 kernel 需要的 metadata。 - -### 12.3 `init_forward_metadata_replay_cuda_graph()`:重建本轮 metadata - -replay 时,plugin backend 不会继续沿用 capture 时那一轮的 metadata 值,而是: - -- 在固定 graph buffer 上 -- 根据本轮真实 batch 重建一次 metadata - -这一步必须非常小心“当前 bucket 视图”和“整块静态 buffer”的区别。 - -最近在 debug 中出现的一个典型 graph-only bug 正是: - -- 上游 replay 某条 speculative 路径把整块静态 buffer 传下来 -- plugin backend 按“已经是当前 `bs` 视图”去理解 -- 于是出现: - - `bs = 1` - - `seq_lens.shape[0] = 48` - -后来在 plugin backend 里做了统一切片规整: - -- `req_pool_indices = req_pool_indices[:bs]` -- `seq_lens = seq_lens[:bs]` -- `seq_lens_cpu = seq_lens_cpu[:bs]` - -本质上就是把 replay 的输入重新对齐到“当前 bucket 视图”。 - - -## 13. 这次 debug 暴露出的两个 graph-only 经验 - -### 13.1 backend 选型必须真的落到 plugin backend - -之前 `kv_last_page_len` 掉到 CPU 的问题,最后定位到: - -- `AiterMultiStepDraftBackend` 内部直接实例化 `AiterAttnBackend` -- 绕过了 plugin 通过 registry 注册的 `"aiter" -> ATOMAttnBackendForSgl` - -这说明: - -- graph-only 路径里,某些 backend 可能不是从常规 registry 路径拿到的 -- 如果 direct construction 没 patch 到,graph state 就可能偷偷回落到 upstream 实现 - -### 13.2 replay 必须明确区分“静态大 buffer”和“当前 bucket 视图” - -graph replay 中,静态 buffer 通常按: - -- `max_bs` -- `max_num_token` - -一次性分配。 - -但 backend 在构 metadata 时,真正应该看到的是: - -- 当前 bucket 的前 `bs` -- 当前 token 的前 `num_tokens` - -一旦把整块静态 buffer 当成当前视图使用,就很容易出现: - -- shape mismatch -- CPU / CUDA tensor 混用 -- metadata 与实际 batch 不一致 - - -## 14. 用一句工程化的话总结 - -如果只用一句最工程化的话来总结这篇文档: - -- `decode` 图里,metadata 大多是 **固定形状 graph 的输入数据** -- `prefill` 图里,metadata 往往会变成 **决定图形状和执行路径的结构量** - -因此: - -- `decode` 适合用 bucket + padding + 静态 buffer 做 graph -- 普通 `prefill / extend` 则很难在收益合理的前提下复用同一张 graph - - -## 15. 最后总结 - -记住下面五句话就够了: - -1. `CUDAGraph` 固定的是一整段具体 CUDA 执行计划,不只是 Python 函数入口。 -2. `raw_bs` 是真实请求数,`bs` 是 graph bucket,`num_tokens` 是真正传给很多 layer 的 token 数。 -3. `decode` 更适合 graph,因为每个 request 的 query 结构更稳定,metadata 更像输入参数。 -4. 普通 `prefill / extend` 难 graph,因为 metadata 会影响分支、shape、workspace 和 kernel 形态,升级成图结构的一部分。 -5. 在 `SGLang + ATOM plugin` 里,graph replay 的关键不是“重复使用旧 metadata 值”,而是“在固定 buffer/地址上重建本轮 metadata 内容”。 diff --git a/work_log/MTP/2026-04-10-sglang-simple-prefill-cudagraph-metadata-guide.md b/work_log/MTP/2026-04-10-sglang-simple-prefill-cudagraph-metadata-guide.md deleted file mode 100644 index c2e4e4b6b..000000000 --- a/work_log/MTP/2026-04-10-sglang-simple-prefill-cudagraph-metadata-guide.md +++ /dev/null @@ -1,866 +0,0 @@ -# 最简单 Prefill、CUDAGraph 与 Metadata 速查 - -## 文档目的 - -这篇文档只回答一个收窄后的问题: - -- **不考虑不同 kernel path** -- **不考虑 prefix cache** -- **不考虑 speculative target_verify / draft_extend** -- **只考虑最普通、最简单的 prefill / extend** - -在这个前提下,说明: - -1. 为什么这种最简单的 prefill 仍然难做 `CUDAGraph` -2. attention metadata 的核心字段在这种场景下分别表示什么 -3. 给几个可以手算的小例子,方便以后速查 - - -## 一句话结论 - -即使只看最简单 prefill,`CUDAGraph` 的挑战仍然存在。根本原因不是“metadata 会变”本身,而是: - -- `total_tokens` 会变 -- `max_q_len / max_kv_len` 会变 -- ragged metadata 会跟着 batch 几何结构一起变 -- 很多中间 tensor / workspace 的 shape 也会变 - -所以问题不是: - -- kernel 能不能读取不同的 metadata 值 - -而是: - -- **同一个 prefill batch family,能不能稳定成一张固定 shape 的图** - - -## 1. 本文说的“最简单 prefill”是什么 - -这里约定的“最简单 prefill”是: - -- 没有 prefix cache -- 没有 speculative 分支 -- 不讨论不同 kernel path 的切换 -- 假定已经选中某一条固定的 prefill kernel 路径 -- 一个 batch 里有若干 request -- 每个 request 本轮需要处理若干 query token -- attention 以 ragged / varlen 形式运行 - -可以把它理解成: - -- `ForwardMode.EXTEND` -- `spec_info = None` -- `extend_prefix_lens = 0` -- attention backend 已经决定“就走这条 prefill kernel” - -本文不讨论: - -- prefix/no-prefix 的 kernel 分流 -- absorbed / decompress 等 MLA 专有分流 -- draft_extend / target_verify -- decode - - -## 2. 最简单 prefill 的数据形状 - -prefill 和 decode 最大的不同在于: - -- decode 常常是每个 request 本轮只算 1 个 token -- prefill 常常是每个 request 本轮要算多个 token - -因此,很多 layer 真正看到的不是: - -- `[bs, hidden_size]` - -而是: - -- `[total_tokens, hidden_size]` - -其中: - -- `bs` = request 数 -- `total_tokens` = 本轮所有 request 的 query token 总数 - -在最简单 prefill 下,常见关系是: - -```text -total_tokens = sum(extend_seq_lens) -``` - -这意味着: - -- 即使 `bs` 不变 -- 只要每个 request 的长度分布变了 -- `total_tokens` 就会变 - - -## 3. 为什么最简单 prefill 仍然难做 CUDAGraph - -下面只看“最简单 prefill”,不引入分支复杂度。 - -### 3.1 `q/k/v/o` 的 token 维会变 - -prefill 下最直观的问题就是: - -- `q.shape[0] = total_tokens` -- `k.shape[0] = total_tokens` -- `v.shape[0] = total_tokens` -- `o.shape[0] = total_tokens` - -只要: - -- request 数不同 -- 或每个 request 的 query 长度分布不同 - -那么: - -- `total_tokens` 就不同 -- 上面这些张量 shape 就不同 - -而 `CUDAGraph` 更喜欢的是: - -- tensor shape 固定 -- graph 中 kernel launch 形态固定 - -这已经是第一层挑战。 - -### 3.2 `max_q_len / max_kv_len` 会变 - -即使不看 `q.shape[0]`,varlen attention 往往还会显式传: - -- `max_q_len` -- `max_kv_len` -- `cu_seqlens_q` 或 `qo_indptr` - -这些量不是 decoration,而是 kernel 的核心输入。 - -例如对于一个 batch: - -- request A: 3 tokens -- request B: 2 tokens - -则: - -- `max_q_len = 3` - -如果下一个 batch 是: - -- request A: 4 tokens -- request B: 1 token - -则: - -- `max_q_len = 4` - -虽然: - -- 两个 batch 的 `bs = 2` -- 两个 batch 的 `total_tokens = 5` - -但: - -- `qo_indptr` 不同 -- `max_q_len` 不同 - -这意味着: - -- 内部 tile / launch 策略可能不同 -- workspace 需求也可能不同 - -### 3.3 ragged metadata 在描述“问题几何结构” - -在 decode 里,很多 metadata 更像: - -- 固定图上的输入索引数据 - -而在 prefill 里,metadata 往往在表达: - -- 一共有多少 query token -- 这些 query token 怎样按 request 分段 -- KV token 怎样按 request 分段 -- 当前 batch 的最大 query / KV 长度是多少 - -所以它不只是“值会变”,而是在描述: - -- **这轮 attention 问题本身长什么样** - -这就让同一张图更难复用。 - -### 3.4 中间 tensor / workspace 的 shape 也会变 - -哪怕我们强行假设: - -- kernel path 不变 - -很多中间结构也仍然可能随 batch 变化。 - -例如: - -- 某些临时索引张量长度跟 `total_tokens` 走 -- 某些 workspace 大小跟 `max_q_len / max_kv_len` 走 -- 某些 reduce buffer 大小跟分段结构走 - -所以问题不止在输入张量,而是: - -- graph 内部很多“中间物体”的 shape 也不稳定 - -### 3.5 同一个分支里也可能无法稳定 replay - -这点最容易误解。 - -即使你已经保证: - -- 一定走同一个 prefill kernel path - -也不代表可以 graph。 - -因为同一条路径里仍然可能有: - -- `total_tokens` 变化 -- `max_q_len` 变化 -- `max_kv_len` 变化 -- workspace shape 变化 - -所以: - -- “分支固定” - -并不等于: - -- “图固定” - - -## 4. Metadata 速查表 - -下面只保留最常用于“最简单 prefill”理解的字段。 - -### 4.1 高层 batch 字段 - -| 字段 | 常见 shape | 物理意义 | 简单 prefill 下的作用 | -|------|------------|----------|-----------------------| -| `bs` | Python `int` | request 数 | 当前 batch 有几个 request | -| `extend_seq_lens` | `[bs]` | 每个 request 本轮 query token 数 | 决定 `total_tokens` 和 `qo_indptr` | -| `seq_lens` | `[bs]` | 每个 request 当前可见 KV 长度 | 决定 `kv_indptr` 和 `max_kv_len` | -| `seq_lens_sum` | Python `int` | 所有 request KV 长度总和 | 常用于辅助构造 KV metadata | -| `req_pool_indices` | `[bs]` | request 在 `req_to_token` 里的行号 | 用来从映射表里取物理 KV slot | - -### 4.2 Query 侧 metadata - -| 字段 | 常见 shape | 物理意义 | 简单 prefill 下的作用 | -|------|------------|----------|-----------------------| -| `qo_indptr` | `[bs + 1]` | flatten 后每个 request 的 query 段边界 | 告诉 kernel 哪些 query 属于哪个 request | -| `max_q_len` | Python `int` | batch 内单 request 最大 query 长度 | kernel 的长度上限参数 | -| `total_tokens` | Python `int` | flatten 后 query token 总数 | 决定很多输入/输出第一维 | - -### 4.3 KV 侧 metadata - -| 字段 | 常见 shape | 物理意义 | 简单 prefill 下的作用 | -|------|------------|----------|-----------------------| -| `kv_indptr` | `[bs + 1]` | flatten 后每个 request 的 KV 段边界 | 告诉 kernel 每段 KV 从哪里到哪里 | -| `kv_indices` | `[sum(seq_lens)]` 或相近长度 | flatten 后每个 KV token 对应的物理 slot | 真正告诉 kernel 去读哪些物理 KV | -| `max_kv_len` | Python `int` | batch 内单 request 最大 KV 长度 | kernel 的 KV 长度上限参数 | -| `kv_last_page_len` | `[bs]` | 每个 request 最后一页有效 token 数 | paged MLA kernel 常用 | -| `kv_lens` | `[bs]` | 每个 request 当前 KV 长度 | 在 page-table 表达里常用 | -| `page_table` | `[bs, max_pages]` | request 到 page id 的二维映射 | 非 MLA / page-table 风格 backend 常用 | - - -## 5. 这些字段的物理意义,最简单地怎么记 - -### 5.1 `qo_indptr` - -记法: - -- 它是 query 侧的 CSR 前缀和边界表 - -典型 shape: - -- `[bs + 1]` - -dtype: - -- 通常是 `int32` - -含义: - -- 第 `i` 个 request 的 query 在 flatten Q 中的范围是: - - `[qo_indptr[i], qo_indptr[i+1])` - -和哪些量对应: - -- `qo_indptr[0]` 固定是 `0` -- `qo_indptr[-1]` 通常等于: - - `total_tokens` -- `qo_indptr[i + 1] - qo_indptr[i]` 等于: - - 第 `i` 个 request 的 query 长度 - -### 5.2 `kv_indptr` - -记法: - -- 它是 KV 侧的 CSR 前缀和边界表 - -典型 shape: - -- `[bs + 1]` - -dtype: - -- 通常是 `int32` - -含义: - -- 第 `i` 个 request 的 KV 在 flatten `kv_indices` 中的范围是: - - `[kv_indptr[i], kv_indptr[i+1])` - -和哪些量对应: - -- `kv_indptr[0]` 固定是 `0` -- `kv_indptr[-1]` 通常等于: - - `len(kv_indices)` -- `kv_indptr[i + 1] - kv_indptr[i]` 等于: - - 第 `i` 个 request 当前参与 attention 的 KV 长度 - -### 5.3 `kv_indices` - -记法: - -- 它是“这次 attention 真正要访问的物理 KV slot 列表” - -典型 shape: - -- `[sum(seq_lens)]` -- 更严格一点说: - - `[kv_indptr[-1]]` - -dtype: - -- 通常是 `int32` - -含义: - -- 每个元素都是一个 physical KV slot id - -更具体一点: - -- `kv_indices` 不是“第几个 token” -- 也不是“第几个 request” -- 它是: - - **flatten 后,每个 KV token 在物理 KV cache 里的实际位置** - -它和下面几个量要一起看: - -- `req_pool_indices` - - shape 通常是 `[bs]` - - 告诉你“当前 batch 里每个 request 对应 `req_to_token` 的哪一行” -- `req_to_token` - - shape 通常是 `[req_pool_size, max_context_len]` - - 告诉你“这个 request 的逻辑第 `j` 个 token,物理上写在 KV cache 的哪个 slot” -- `seq_lens` - - shape 通常是 `[bs]` - - 告诉你“这个 request 当前有多少个 KV token 参与 attention” -- `kv_indptr` - - shape 通常是 `[bs + 1]` - - 告诉你“这个 request 对应的 KV 段,在 flatten 后 `kv_indices` 里的哪一段” - -所以可以把 `kv_indices` 理解成: - -- 先按 `req_pool_indices` 找到每个 request 在 `req_to_token` 中的那一行 -- 再按 `seq_lens[i]` 取出这行前面的有效 token 映射 -- 最后把所有 request 的映射段拼接起来 - -也就是说: - -- `kv_indptr` 负责“分段边界” -- `kv_indices` 负责“段内具体有哪些 physical slot” - -### 5.3.1 它为什么重要 - -attention kernel 真正关心的不是: - -- “这是第几个逻辑 token” - -而是: - -- “要去 KV cache 的哪个物理位置读 K/V” - -`kv_indices` 正是在回答这个问题。 - -如果没有 `kv_indices`,kernel 只知道: - -- batch 里有几个 request -- 每个 request 长度是多少 - -但它仍然不知道: - -- 这些 request 的历史 token 到底落在 KV cache 里的哪些 physical slot 上 - -### 5.3.2 它为什么通常是 flatten 的 - -`kv_indices` 做成一维 flatten 形式,而不是二维 `[bs, max_kv_len]`,是因为: - -- 不同 request 的 KV 长度不一样 -- ragged attention 更自然的表示法就是: - - 一条长数组 - - 再配一个 `kv_indptr` - -这和 CSR 稀疏矩阵的表达方式很像: - -- `kv_indices` = 数据主体 -- `kv_indptr` = 每段边界 - -### 5.3.3 它和 `total_tokens` / `max_kv_len` 的区别 - -这几个量很容易混: - -- `total_tokens` - - shape 是标量 / Python `int` - - query 侧总 token 数 -- `max_kv_len` - - shape 是标量 / Python `int` - - 单 request 最大 KV 长度 -- `kv_indices` - - shape 是一维张量 `[sum(seq_lens)]` - - 这轮 attention 真正要访问的所有 physical KV slot 列表 - -它们不是一回事。 - -例如: - -- `bs = 2` -- `seq_lens = [3, 2]` - -那么: - -- `max_kv_len = 3` -- `len(kv_indices) = 5` - -前者是“最大段长度”,后者是“所有段拼起来后的总长度”。 - -### 5.3.4 一个更完整的手算例子 - -假设: - -- `req_pool_indices = [7, 9]` - - shape: `[2]` -- `seq_lens = [5, 3]` - - shape: `[2]` -- `req_to_token[7, 0:5] = [100, 101, 102, 103, 120]` -- `req_to_token[9, 0:3] = [200, 201, 220]` - -那么先算边界: - -```text -kv_indptr = [0, 5, 8] -``` - -它的 shape 是: - -- `[3]`,也就是 `[bs + 1]` - -再按每个 request 的有效长度取映射: - -- request 0 取: - - `[100, 101, 102, 103, 120]` -- request 1 取: - - `[200, 201, 220]` - -最后拼接得到: - -```text -kv_indices = [100, 101, 102, 103, 120, 200, 201, 220] -``` - -它的 shape 是: - -- `[8]` -- 也就是: - - `[sum(seq_lens)] = [5 + 3]` - -于是: - -- request 0 的 KV 段是: - - `kv_indices[kv_indptr[0]:kv_indptr[1]]` - - 也就是 `kv_indices[0:5]` -- request 1 的 KV 段是: - - `kv_indices[kv_indptr[1]:kv_indptr[2]]` - - 也就是 `kv_indices[5:8]` - -### 5.3.5 debug 时怎么看 `kv_indices` - -如果你在 debug attention metadata,`kv_indices` 最值得看两件事: - -1. 长度对不对 - -- 在最简单 prefill 里,通常应该有: - - `len(kv_indices) == sum(seq_lens)` -- 也可以写成: - - `kv_indices.shape == (int(seq_lens.sum()),)` - -如果这个关系都不对,说明: - -- `kv_indptr` -- `seq_lens` -- 或 `req_to_token` 的使用 - -有地方没对齐。 - -2. 分段内容对不对 - -给定: - -- `kv_indptr` - - shape: `[bs + 1]` -- `kv_indices` - - shape: `[sum(seq_lens)]` - -你应该能把每个 request 对应的 physical slot 段切出来,并和: - -- `req_to_token[row, :seq_len]` - -一一对应上。 - -如果切出来的段和 `req_to_token` 不对应,常见意味着: - -- `req_pool_indices` 行号不对 -- `seq_lens` 不是这轮应看的 KV 长度 -- 或者 graph replay 时把“整块静态 buffer”错当成了“当前 bucket 视图” - -### 5.4 `max_q_len` - -记法: - -- 这轮 batch 中,单 request 最长的 query 长度 - -shape: - -- 标量 / Python `int` - -它不是: - -- 所有 query 总数 - -### 5.5 `max_kv_len` - -记法: - -- 这轮 batch 中,单 request 最长的 KV 长度 - -shape: - -- 标量 / Python `int` - - -## 6. 最简单 prefill 的三个小例子 - -### 例子 1:单 request prefill - -假设: - -- `bs = 1` -- `extend_seq_lens = [5]` - - shape: `[1]` -- `seq_lens = [5]` - - shape: `[1]` - -那么: - -- `total_tokens = 5` -- `qo_indptr = [0, 5]` - - shape: `[2]` -- `kv_indptr = [0, 5]` - - shape: `[2]` -- `max_q_len = 5` -- `max_kv_len = 5` - -如果 `req_to_token[row]` 对应的是: - -- `[100, 101, 102, 103, 104]` - -那么: - -- `kv_indices = [100, 101, 102, 103, 104]` - - shape: `[5]` - -这个例子很简单,但也正好说明: - -- 这轮 graph 里主干很多 tensor 第一维都是 `5` - -如果下一个 batch 变成 `7` 个 token: - -- 图里的很多 shape 就都要变 - -### 例子 2:两个 request,长度不同 - -假设: - -- `bs = 2` -- `extend_seq_lens = [3, 2]` - - shape: `[2]` -- `seq_lens = [3, 2]` - - shape: `[2]` - -那么: - -- `total_tokens = 5` -- `qo_indptr = [0, 3, 5]` - - shape: `[3]` -- `kv_indptr = [0, 3, 5]` - - shape: `[3]` -- `max_q_len = 3` -- `max_kv_len = 3` - -如果: - -- request 0 的物理 slot 是 `[10, 11, 12]` -- request 1 的物理 slot 是 `[20, 21]` - -那么: - -- `kv_indices = [10, 11, 12, 20, 21]` - - shape: `[5]` - -这里最值得注意的是: - -- `total_tokens = 5` -- 但 request 分段结构已经不是均匀的 - -### 例子 2.1:把 `qo_indptr + kv_indptr + kv_indices` 放在一起看 - -继续沿用上面的 batch: - -- `bs = 2` -- `extend_seq_lens = [3, 2]` -- `seq_lens = [3, 2]` -- `qo_indptr = [0, 3, 5]` -- `kv_indptr = [0, 3, 5]` -- `kv_indices = [10, 11, 12, 20, 21]` - -如果把 flatten 后的 Q token 记成: - -```text -Q_flat = [q0, q1, q2, q3, q4] -``` - -那么 query 侧分段是: - -- request 0: - - `Q_flat[0:3]` - - 也就是 `q0, q1, q2` -- request 1: - - `Q_flat[3:5]` - - 也就是 `q3, q4` - -因为: - -```text -qo_indptr = [0, 3, 5] -``` - -同样,KV 侧分段是: - -- request 0: - - `kv_indices[0:3]` - - 也就是 `[10, 11, 12]` -- request 1: - - `kv_indices[3:5]` - - 也就是 `[20, 21]` - -因为: - -```text -kv_indptr = [0, 3, 5] -kv_indices = [10, 11, 12, 20, 21] -``` - -把它们并排看,就是: - -```text -request 0: - Q range = [qo_indptr[0], qo_indptr[1]) = [0, 3) - Q tokens = [q0, q1, q2] - KV range = [kv_indptr[0], kv_indptr[1]) = [0, 3) - KV slots = [10, 11, 12] - -request 1: - Q range = [qo_indptr[1], qo_indptr[2]) = [3, 5) - Q tokens = [q3, q4] - KV range = [kv_indptr[1], kv_indptr[2]) = [3, 5) - KV slots = [20, 21] -``` - -这就是 ragged attention metadata 最核心的意思: - -- `qo_indptr` - - 告诉 kernel:flatten 后哪些 query 属于哪个 request -- `kv_indptr` - - 告诉 kernel:flatten 后哪些 KV 段属于哪个 request -- `kv_indices` - - 告诉 kernel:这个 request 的 KV 段具体对应哪些 physical KV slot - -如果再把 `req_to_token` 写出来: - -```text -req_to_token[row_of_req0, 0:3] = [10, 11, 12] -req_to_token[row_of_req1, 0:2] = [20, 21] -``` - -那就能看到: - -- `kv_indices` - 本质上就是把每个 request 在 `req_to_token` 里的有效前缀切出来,再按 request 顺序拼起来。 - -### 例子 3:`total_tokens` 一样,但 graph 仍然难复用 - -看两个 batch: - -#### batch A - -- `bs = 2` -- `extend_seq_lens = [3, 2]` - - shape: `[2]` - -得到: - -- `total_tokens = 5` -- `qo_indptr = [0, 3, 5]` - - shape: `[3]` -- `max_q_len = 3` - -#### batch B - -- `bs = 2` -- `extend_seq_lens = [4, 1]` - - shape: `[2]` - -得到: - -- `total_tokens = 5` -- `qo_indptr = [0, 4, 5]` - - shape: `[3]` -- `max_q_len = 4` - -这两个 batch: - -- `bs` 相同 -- `total_tokens` 相同 - -但: - -- `qo_indptr` 不同 -- `max_q_len` 不同 - -这说明: - -- 即使总 token 数没变 -- prefill 的“问题几何结构”仍然变了 - -这就是 graph 复用困难的关键例子。 - -### 例子 4:为什么 decode 更容易 graph - -假设 decode: - -- `bs = 2` -- 每个 request 本轮只解 1 个 token - -那么: - -- `total_tokens = 2` -- `qo_indptr = [0, 1, 2]` - - shape: `[3]` -- `max_q_len = 1` - -下一个 batch 只要 bucket 还是这个 `bs`,即使: - -- `kv_indices` -- `seq_lens` -- `kv_indptr` - -的内容变了,graph 里主干 shape 往往还是稳定得多。 - -所以: - -- decode 中 metadata 更像“数据表” -- prefill 中 metadata 更像“几何结构描述” - - -## 7. 如果硬要对最简单 prefill 做 graph,需要什么条件 - -最少需要做下面几件事中的一些: - -### 7.1 固定 `bs` - -最基础的 bucket 化: - -- 只允许某几个 `bs` 值 - -但仅固定 `bs` 还不够。 - -### 7.2 固定 `total_tokens` - -因为很多输入/输出 tensor 的第一维是: - -- `total_tokens` - -若它不固定,graph 很难复用。 - -### 7.3 固定 `max_q_len / max_kv_len` - -因为它们常常影响: - -- kernel launch 形态 -- workspace 大小 - -### 7.4 固定 workspace 形状 - -也就是说: - -- 需要让中间临时张量有固定上限 -- 或者直接预分配到某个 bucket 上限 - -### 7.5 允许 padding / pack / unpack - -最现实的手段通常是: - -- graph 外把 ragged batch 归一化 -- graph 内只处理固定形状张量 -- graph 后再 unpad - -但代价是: - -- 额外数据搬运 -- padding 带来的无效计算 - - -## 8. 为什么这比 decode 难很多 - -可以用一句最简单的话来对比: - -- `decode` 的不确定性主要是“数据值不同” -- `prefill` 的不确定性主要是“问题结构不同” - -decode 常常可以做到: - -- 固定 `num_tokens_per_bs = 1` -- 固定 `max_q_len = 1` -- 只靠 `bs bucket` 就稳定大部分形状 - -而最简单 prefill 仍然会遇到: - -- `total_tokens` 变化 -- `qo_indptr` 变化 -- `max_q_len` 变化 -- `max_kv_len` 变化 -- 中间 workspace 变化 - - -## 9. 最后总结 - -只记下面六句话就够了: - -1. 最简单 prefill 也不是固定 shape 问题,而是 ragged / varlen 问题。 -2. `total_tokens = sum(extend_seq_lens)`,它决定了很多主干张量的第一维。 -3. `qo_indptr` 和 `kv_indptr` 不是装饰字段,而是在描述这轮 attention 的分段几何结构。 -4. `max_q_len / max_kv_len` 会随着 batch 分布变化,常常进一步影响 kernel 和 workspace。 -5. 即使不考虑不同 kernel path,prefill 仍然可能因为 shape 和 workspace 不稳定而难以复用同一张 graph。 -6. 如果真的想 graph 化最简单 prefill,通常还需要 bucket 化、padding 或 pack/unpack 来先把 ragged 问题归一化。 diff --git a/work_log/MTP/MTP-2026-04-08.md b/work_log/MTP/MTP-2026-04-08.md deleted file mode 100644 index ad368de3a..000000000 --- a/work_log/MTP/MTP-2026-04-08.md +++ /dev/null @@ -1,525 +0,0 @@ -# 2026-04-08 MTP 调研与调试记录 - -## 目标 - -本次工作的目标是调研并尝试推进 `ATOM + SGLang plugin` 路径下的 -DeepSeek MTP 接入,重点回答下面几个问题: - -- `ATOM/atom/plugin/sglang` 当前到底支持了什么 -- upstream SGLang 的 DeepSeek MTP / NextN 是怎么组织的 -- 当前启动命令实际跑起来时,target model 和 draft model 分别是谁 -- 当前失败点落在什么地方,根因是什么 -- 如果后续正式推进,推荐的技术路线是什么 - - -## 本次结论速览 - -- `ATOM sglang plugin` 当前并没有真正把 `ATOM/atom/models/deepseek_mtp.py` - 接到 draft/MTP 路径上。 -- 当前运行形态更接近: - - target model 走 `ATOM plugin wrapper + ATOM DeepseekV3ForCausalLM` - - draft model 走 upstream SGLang 的 `DeepseekV3ForCausalLMNextN` -- 换句话说,当前不是“ATOM MTP 已经接通”,而是: - - `ATOM target + SGLang NextN draft + ATOM target verify backend` -- 本次已经修掉了第一个显式接口不兼容问题: - - upstream speculative worker 需要 target model 提供 - `get_embed_and_head()/set_embed_and_head()/set_embed()` - - `ATOM plugin wrapper` 原本没有这些接口 -- 当前新的阻塞点在: - - `TARGET_VERIFY` 路径进入 `ATOM` 的 - `sgl_attn_backend.py` - - ATOM plugin 把 verify 当成普通 extend 处理 - - 于是错误访问了 `forward_batch.extend_seq_lens` - - 但在 verify 路径下这个字段本来就可能为 `None` - - -## 背景知识 - -### 1. upstream SGLang 的 DeepSeek MTP / NextN 组织方式 - -在 upstream SGLang 里,DeepSeek 的 draft/MTP 不是通过一个独立的 -`DeepSeekMTP` 类来暴露给 speculative runtime,而是通过一个 -SGLang 风格的 draft model 壳: - -- `sglang/python/sglang/srt/models/deepseek_nextn.py` -- 类名:`DeepseekV3ForCausalLMNextN` - -这层壳的特点: - -- 对外长得像标准的 `ForCausalLM` -- 能直接被 `ModelRegistry` 解析和实例化 -- 带有 `load_weights(..., is_nextn=True)` -- 带有 `get_embed_and_head()/set_embed_and_head()` -- 能直接对接 SGLang speculative worker - -它内部并不会真的再构一个完整的 target DeepSeek 模型,而是构一个 -更薄的 NextN draft 结构。 - - -### 2. ATOM 的 MTP 组织方式 - -ATOM 侧则是另一种设计: - -- `ATOM/atom/models/deepseek_mtp.py` -- 类名:`DeepSeekMTP` - -这更像一个 draft core,而不是一个完整的 SGLang 风格 runtime wrapper。 -它暴露的是: - -- `forward(input_ids, positions, hidden_states, ...)` -- `compute_logits(hidden_states, spec_step_idx=...)` - -也就是说: - -- upstream SGLang:偏“运行时壳子” -- ATOM:偏“底层 draft 模型核心” - - -### 3. 为什么这是个关键差异 - -这意味着 plugin 端后续有两种思路: - -1. 继续沿用 upstream 的思路,在 plugin 里做一个 - `DeepseekV3ForCausalLMNextN` 风格的壳 -2. 尽量复用 `ATOM/atom/models/deepseek_mtp.py`,只补一层很薄的 - SGLang 兼容 wrapper - -本次调研后的倾向是: - -- 不建议在 plugin 里再复制一整套 upstream NextN 继承链 -- 更推荐“上层保留 SGLang 兼容接口,下层复用 ATOM DeepSeekMTP” - - -### 4. speculative 运行时里的几个对象容易混淆 - -在 SGLang speculative 模式下,scheduler 里会出现多个 worker: - -- `tp_worker` - - target `TpModelWorker` -- `draft_worker` - - 变量名容易误导 - - 在 scheduler 里它其实通常是 speculative orchestrator - - 例如 `EAGLEWorker` / `EAGLEWorkerV2` -- 真正的 draft `TpModelWorker` - - 在 orchestrator 内部 - -所以: - -- `self.model_worker = self.draft_worker` - -并不是“target worker 被 draft worker 替代”,而是: - -- scheduler 把统一执行入口切到了 speculative orchestrator -- orchestrator 再内部协调: - - draft propose - - target verify - - draft extend - - -### 5. `embed_and_head` 是什么,为什么 drafter 需要 - -upstream speculative worker 在初始化 draft model 时,会从 target model 取: - -- `embed = embed_tokens.weight` -- `head = lm_head.weight` - -原因: - -- drafter 需要把 token id 变成 embedding,再继续往下算 -- drafter 也需要把 hidden state 变成 logits,提议下一个 token -- 共享 target 的 embedding / lm_head 可以: - - 节省显存 - - 保持 vocab 完全一致 - - 避免 draft 再重复加载一份大权重 - - -## 当前代码状态理解 - -### 1. ATOM plugin 当前只导出了哪些 model - -文件: - -- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` -- `ATOM/atom/plugin/register.py` - -当前 external package 暴露的 `_MODEL_NAMES` 只有: - -- `DeepseekV3ForCausalLM` -- `Qwen3MoeForCausalLM` - -ATOM plugin 支持的 `_ATOM_SUPPORTED_MODELS` 里也没有: - -- `DeepseekV3ForCausalLMNextN` -- `DeepSeekMTPModel` - -这意味着: - -- target `DeepseekV3ForCausalLM` 可以被 ATOM external package 覆盖 -- draft `DeepseekV3ForCausalLMNextN` 不会被 ATOM external package 覆盖 - - -### 2. 为什么 target 在 `prepare_model()` 里还是 `DeepseekV3ForCausalLM` - -文件: - -- `ATOM/atom/plugin/prepare.py` -- `sglang/python/sglang/srt/configs/model_config.py` - -我们在 `prepare_model()` 打日志时看到: - -- `model_arch in prepare_model: DeepseekV3ForCausalLM` - -这并不矛盾,因为那个 `prepare_model()` 调用发生在 target 路径。 - -而 draft 路径是另外一条 worker 初始化链,且在 `ModelConfig._config_draft_model()` -里会把: - -- `DeepseekV3ForCausalLM` - -> `DeepseekV3ForCausalLMNextN` - -所以: - -- target 看到 `DeepseekV3ForCausalLM` 正常 -- draft 并不经过同一个 `prepare_model()` 观察点 - - -### 3. 当前实际 load 的 draft module 是谁 - -从本次运行已经走到 speculative verify 阶段可以判断: - -- draft worker 已经成功创建 -- draft model 已经成功 load -- draft model 不是完全没起来 - -结合当前注册关系,最合理的判断是: - -- target model:ATOM `DeepseekV3ForCausalLM` -- draft model:upstream SGLang `DeepseekV3ForCausalLMNextN` - -不是: - -- ATOM `DeepSeekMTP` - - -## 本次实验与过程记录 - -### 实验 1:总体调研 ATOM sglang plugin 中 MTP 的现状 - -动机: - -- 先搞清楚 plugin 里到底有哪些 speculative / MLA / MTP 相关代码 -- 避免一上来就在错误层面改代码 - -主要阅读文件: - -- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` -- `ATOM/atom/plugin/sglang/attention_backend/sgl_attention_mla.py` -- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` -- `ATOM/atom/plugin/register.py` -- `ATOM/atom/plugin/prepare.py` -- `ATOM/atom/models/deepseek_mtp.py` -- `ATOM/atom/spec_decode/eagle.py` -- `sglang/python/sglang/srt/models/deepseek_nextn.py` -- `sglang/python/sglang/srt/layers/attention/aiter_backend.py` - -结果: - -- plugin 侧已经部分支持 speculative-aware 的 MLA 计算路径 -- 但没有完整接入 draft/MTP model -- 真正完整的 `DeepSeekMTP` 在 ATOM 原生链路里,不在当前 plugin draft 路径里 - - -### 实验 2:确认 target / draft 两条模型构造链 - -动机: - -- 理清 scheduler 里为什么既有 target worker 又有 draft worker -- 理清为什么 target model 和 draft model 不一定走同一套 model class - -关键结论: - -- speculative 模式下,运行时确实同时存在 target model 和 draft model -- 但二者最终还是都走通用 loader / `_initialize_model()` -- 区别在于: - - target 的 `ModelConfig` 正常走原始架构 - - draft 的 `ModelConfig` 会先做 `is_draft_model=True` 的架构改写 - - -### 实验 3:首个阻塞点 - `get_embed_and_head` 缺失 - -报错: - -- `AttributeError: 'DeepseekV3ForCausalLM' object has no attribute 'get_embed_and_head'` - -发生点: - -- `sglang/python/sglang/srt/speculative/eagle_worker.py` - -动机: - -- 需要确认这是 draft model 没构出来,还是 target/draft 接口不兼容 - -分析结果: - -- 不是 draft 没构出来 -- upstream speculative worker 在初始化 draft 时,要从 target model 取 - embedding / lm_head -- 但 `ATOM plugin wrapper` 没有把这几个接口暴露给外层 - - -### 实验 4:修复 wrapper 与 upstream speculative worker 的接口契约 - -改动文件: - -- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` - -改动动机: - -- 让 target wrapper 满足 upstream speculative worker 期待的最小接口 - -最终保留的最小接口: - -- `get_embed_and_head()` -- `set_embed_and_head()` -- `set_embed()` - -说明: - -- 一开始尝试把 `get_embed_and_head` 打到 inner `self.model` 上 -- 但 upstream 调的是外层 wrapper 对象 -- 所以最终改成正式的 wrapper 成员方法 - -结果: - -- `get_embed_and_head` 的报错被消除 -- 程序继续向前推进到了 speculative verify 阶段 - - -### 实验 5:新的阻塞点 - verify 路径 metadata 初始化错误 - -最新报错: - -- `AttributeError: 'NoneType' object has no attribute 'max'` - -调用栈末端: - -- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` -- `_init_extend_mla()` -- `forward_batch.extend_seq_lens.max().item()` - -初看怀疑: - -- 是不是 `batch` 为 `None` - -进一步定位后确认: - -- 不是 `batch` 为 `None` -- 也不是 `forward_batch` 为 `None` -- 真正为 `None` 的是: - - `forward_batch.extend_seq_lens` - -而且这在 `TARGET_VERIFY` 路径下是正常现象。 - - -## 为什么 `extend_seq_lens` 在 verify 里是 `None` - -文件: - -- `sglang/python/sglang/srt/speculative/eagle_info.py` -- `sglang/python/sglang/srt/managers/schedule_batch.py` - -`prepare_for_verify()` 会做: - -- 改 `batch.input_ids` -- 分配 `out_cache_loc` -- 更新 `req_to_token_pool` - -但不会去填普通 extend 用的 `extend_lens/extend_seq_lens`。 - -之后 `ScheduleBatch.get_model_worker_batch()` 会把 `self.extend_lens` -透传为 `extend_seq_lens`。 - -因此在 verify 路径下: - -- `extend_seq_lens=None` - -是完全可能且合理的。 - - -## 为什么 upstream 不会在这里崩 - -文件: - -- `sglang/python/sglang/srt/layers/attention/aiter_backend.py` - -upstream 的 `AiterAttnBackend.init_forward_metadata()` 不是简单分成 -decode 和 extend 两大类,而是专门区分: - -- `decode_or_idle` -- `draft_extend` -- `target_verify` -- 普通 extend - -其中 `target_verify` 分支会自己根据: - -- `spec_info.draft_token_num` -- `forward_batch.seq_lens` - -来构造: - -- `qo_indptr` -- `kv_indptr` -- `kv_indices` - -它根本不依赖 `forward_batch.extend_seq_lens`。 - - -## 为什么 ATOM plugin 会在这里崩 - -文件: - -- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` - -当前 ATOM plugin 这层把 metadata 初始化逻辑写成了: - -- `decode_or_idle` -> `_init_forward_metadata_decode()` -- 其他全部 -> `_init_forward_metadata_extend()` - -由于在 SGLang 里: - -- `ForwardMode.TARGET_VERIFY` 也被算作 `is_extend()` - -所以 verify 路径被误送进了普通 MLA extend 初始化: - -- `_init_extend_mla()` - -而这个函数又直接假设: - -- `forward_batch.extend_seq_lens` 一定存在 - -于是崩溃。 - -结论: - -- 当前问题不是“verify 输入准备坏了” -- 而是“ATOM plugin 缺 upstream 那段专门的 `TARGET_VERIFY` metadata 分支” - - -## 已做改动 - -### 文件 - -- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` - -### 实际改动 - -新增 wrapper 层方法: - -- `get_embed_and_head()` -- `set_embed_and_head()` -- `set_embed()` - -### 改动目的 - -- 对齐 upstream speculative worker 对 target/draft 共享 embedding/lm_head - 的接口期望 - -### 当前状态 - -- 这个改动已经验证有效 -- 运行不再卡在 `get_embed_and_head` 缺失 - - -## 当前未改的部分 - -本次刻意没有做这些事情: - -- 没有把 `DeepseekV3ForCausalLMNextN` 纳入 ATOM external package -- 没有给 plugin 接上 `ATOM DeepSeekMTP` -- 没有补 `TARGET_VERIFY` 的 metadata 初始化逻辑 -- 没有去动 draft model 的 attention backend - -原因: - -- 需要先把现有混合路径看清楚 -- 先分清楚是接口问题、metadata 问题,还是 draft model 架构问题 - - -## 目前推荐的后续推进顺序 - -### 第一步 - -先把 `TARGET_VERIFY` 在 `ATOM sgl_attn_backend.py` 中的 metadata 初始化补齐。 - -具体方向: - -- 参考 upstream - `sglang/python/sglang/srt/layers/attention/aiter_backend.py` - 的 `is_target_verify()` 分支 -- 不要再让 verify 走通用 `_init_extend_mla()` - - -### 第二步 - -验证当前混合路径是否可以完整跑完: - -- `draft -> target verify -> draft extend` - -如果这一步都跑不通,就还不适合开始切 draft 到 ATOM。 - - -### 第三步 - -在 draft 路径做架构选择: - -推荐方案: - -- 写一个 SGLang 兼容的薄 wrapper -- 内部复用 `ATOM/atom/models/deepseek_mtp.py` - -不推荐方案: - -- 在 plugin 里复制一整套新的 NextN / MTP 继承链 - - -## 关键文件索引 - -### ATOM 侧 - -- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` -- `ATOM/atom/plugin/register.py` -- `ATOM/atom/plugin/prepare.py` -- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` -- `ATOM/atom/plugin/sglang/attention_backend/sgl_attention_mla.py` -- `ATOM/atom/models/deepseek_v2.py` -- `ATOM/atom/models/deepseek_mtp.py` -- `ATOM/launch_deepseek_mtp.sh` - -### upstream SGLang 侧 - -- `sglang/python/sglang/srt/configs/model_config.py` -- `sglang/python/sglang/srt/models/deepseek_nextn.py` -- `sglang/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py` -- `sglang/python/sglang/srt/speculative/eagle_worker.py` -- `sglang/python/sglang/srt/layers/attention/aiter_backend.py` -- `sglang/python/sglang/srt/managers/scheduler.py` -- `sglang/python/sglang/srt/managers/tp_worker.py` -- `sglang/python/sglang/srt/model_executor/forward_batch_info.py` -- `sglang/python/sglang/srt/speculative/eagle_info.py` - - -## 当前会话最终状态 - -- 已明确:当前 draft 不是 ATOM `DeepSeekMTP` -- 已明确:当前 draft 更可能是 upstream `DeepseekV3ForCausalLMNextN` -- 已明确:target 是 ATOM `DeepseekV3ForCausalLM` -- 已修复:wrapper 缺 `get_embed_and_head` 等接口的问题 -- 已定位:新的核心阻塞点是 `TARGET_VERIFY` metadata 初始化不完整 - -因此,本次工作最重要的阶段性成果是: - -- 把“到底是谁在跑 MTP / NextN” -- “当前失败发生在哪一层” -- “后面应该先补哪一段逻辑” - -这三件事彻底理清了。 diff --git a/work_log/MTP/MTP-2026-04-09.md b/work_log/MTP/MTP-2026-04-09.md deleted file mode 100644 index 1127bfb7d..000000000 --- a/work_log/MTP/MTP-2026-04-09.md +++ /dev/null @@ -1,715 +0,0 @@ -# 2026-04-09 ATOM Plugin 模式下 DeepSeek MTP 接入与 CUDAGraph 调试记录 - -## 目标 - -本次工作的目标是继续推进 `ATOM + SGLang plugin` 路径下的 DeepSeek MTP 接入, -并重点解决下面几个问题: - -- 让 `ATOM plugin` 在不修改 upstream `sglang` 的前提下,真正接管 DeepSeek draft/MTP 路径 -- 确认 `SGLang` 当前在 plugin 模式下到底是怎样解析 draft model 的 -- 把 `ATOM/atom/models/deepseek_mtp.py` 绑定到 `SGLang` 期望的 draft model 接口上 -- 修复接入过程中出现的 runtime / speculative / CUDAGraph 相关问题 -- 记录 `ATOM` 与 `SGLang` 在 MTP 抽象上的差异,方便后续继续演进 - - -## 本次结论速览 - -- 当前已经实现: - - 在 `ATOM plugin` 中新增 `DeepseekV3ForCausalLMNextN` wrapper - - 通过 `SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.models` 成功覆盖 upstream 的 draft model registry - - draft 路径已不再直接使用 upstream `sglang/srt/models/deepseek_nextn.py`,而是通过 plugin wrapper 绑定到 `ATOM DeepSeekMTP` -- 当前 draft 接线方式为: - - `SGLang` 仍然认为自己在加载一个独立的 draft model:`DeepseekV3ForCausalLMNextN` - - 但这个 draft model 的内部实现已经被 ATOM plugin 替换为 `ATOM/atom/models/deepseek_mtp.py` 里的 `DeepSeekMTP` -- 当前已经修过的几类问题: - - target wrapper 缺少 `get_embed_and_head / set_embed_and_head / set_embed` - - plugin runtime 中 target / draft 共用全局 `current_atom_config`,导致 MoE 串台 - - draft runtime layer id 使用了 checkpoint 全局层号,导致 KV cache layer index 越界 - - CUDAGraph 初始化时 `TARGET_VERIFY` / `DRAFT_EXTEND` 缺少 metadata 分支 - - `sglang` 的 `RadixAttention` 默认 `k_scale/v_scale` 在 CPU,plugin wrapper 需要显式搬到 CUDA -- 当前仍未完全解决的问题集中在: - - `CUDAGraph + TARGET_VERIFY + MLA decode` 路径下,某些传给 `aiter.mla_decode_stage1_asm_fwd` 的 metadata tensor 仍然落在 CPU - - 当前已明确抓到的一个具体问题是:`kv_last_page_lens(device=cpu, ...)` -- 一个很重要的最新判断: - - 从代码看,`ATOMAttnBackendForSgl` 对 `init_cuda_graph_state()` 的 override 本身是成功的 - - 更可能的问题不是 override 语义失败,而是后续 graph metadata 组装阶段把 `forward_metadata.kv_last_page_len` 绑定成了 CPU tensor - - -## 背景知识 - -### 1. upstream SGLang 的 DeepSeek MTP 组织方式 - -upstream `SGLang` 对 DeepSeek MTP 的处理方式是: - -- draft model 会被改写成一个独立的 model architecture -- 架构名是 `DeepseekV3ForCausalLMNextN` -- 实现文件是: - - `sglang/python/sglang/srt/models/deepseek_nextn.py` - -这意味着在 `SGLang` 看来,DeepSeek MTP / NextN 不是 target model 内部的一段辅助逻辑, -而是一个独立的 draft model 类,具备: - -- 自己的 `EntryClass` -- 自己的 `load_weights(..., is_nextn=True)` -- 自己的 `forward(...)` -- 和 speculative worker 的 embed/head 共享接口 - - -### 2. ATOM 的 DeepSeek MTP 组织方式 - -`ATOM` 里对应的实现是: - -- 文件: - - `ATOM/atom/models/deepseek_mtp.py` -- 类: - - `DeepSeekMTP` - -它更像一个 draft core,而不是一个 SGLang 风格的完整 runtime model 壳子。 - -它暴露的主要接口是: - -- `forward(input_ids, positions, hidden_states, ...)` -- `compute_logits(hidden_states, spec_step_idx=...)` - -也就是说: - -- upstream `SGLang`:偏“独立 draft model 壳子” -- `ATOM`:偏“独立 draft 计算模块” - - -### 3. 为什么 plugin 里需要 wrapper - -因为 `SGLang` 期望加载的是: - -- `DeepseekV3ForCausalLMNextN` - -而 `ATOM` 现成提供的是: - -- `DeepSeekMTP` - -所以 plugin 需要补一层 very thin wrapper,把: - -- `SGLang` 的 draft model 接口 - -映射到: - -- `ATOM DeepSeekMTP` - -这也是本次新增 `deepseek_nextn_wrapper.py` 的根本原因。 - - -### 4. `EntryClass` 在 SGLang external model package 中的作用 - -`SGLang` 的 external model package 机制不是靠显式调用 register API 完成的, -而是约定: - -- 遍历 `SGLANG_EXTERNAL_MODEL_PACKAGE` -- import 这个包下面的所有 module -- 读取每个 module 的 `EntryClass` -- 用 `EntryClass.__name__` 作为 architecture 名称注册进 `ModelRegistry` - -因此,只要: - -- 文件位于 `atom.plugin.sglang.models` -- module import 成功 -- 其中声明了: - - `EntryClass = [DeepseekV3ForCausalLMNextN]` - -那么 `SGLang` 就会用这个类覆盖 upstream 同名 architecture。 - - -### 5. `NEXTN`、`EAGLE`、`EAGLEWorker`、`EAGLEWorkerV2` 的关系 - -这一点在调试中非常关键。 - -当前启动脚本里使用的是: - -- `--speculative-algorithm NEXTN` - -但在 `SGLang` 中,这个参数会被进一步改写为: - -- `EAGLE` - -而最终选择哪个 worker,要看: - -- 是否开启 spec v2 / overlap schedule - -当前这次实验里没有开启: - -- `SGLANG_ENABLE_SPEC_V2=True` - -因此当前实际使用的是: - -- `sglang/python/sglang/srt/speculative/eagle_worker.py` - -而不是: - -- `eagle_worker_v2.py` - - -## 本次主要代码改动 - -### 1. 新增 draft wrapper - -文件: - -- `ATOM/atom/plugin/sglang/models/deepseek_nextn_wrapper.py` - -新增类: - -- `DeepseekV3ForCausalLMNextN` - -目的: - -- 让 `SGLang` 在解析 draft architecture 时命中 plugin 自己的 wrapper -- 在不修改 upstream `sglang` 的前提下,把 draft model 内部实现切到 `ATOM DeepSeekMTP` - -该 wrapper 目前承担的职责包括: - -- 生成 plugin 模式下的 `atom_config` -- 将 config 改写为 `deepseek_mtp` / `DeepSeekMTPModel` 语义 -- 实例化 `ATOM/atom/models/deepseek_mtp.py::DeepSeekMTP` -- 调 `setup_deepseek_for_sglang()` 做 DeepSeek MLA patch -- 暴露: - - `get_embed_and_head()` - - `set_embed_and_head()` - - `set_embed()` -- `forward()` 中消费: - - `forward_batch.spec_info.hidden_states` -- `load_weights()` 时走: - - `load_model(..., spec_decode=True)` - - -### 2. plugin runtime scope 收口 - -文件: - -- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` - -新增 helper: - -- `plugin_runtime_scope(...)` - -目的: - -- 不修改 `ATOM` 非 plugin 目录下全局配置实现 -- 但在 plugin 层控制: - - 当前 framework - - 当前 atom_config - -动机: - -- target wrapper 和 draft wrapper 同时存在时,共用 `ATOM` 全局 runtime state -- 会导致: - - `current_atom_config` 串台 - - MoE 静态上下文读错实例 - -实际效果: - -- target/draft 的 `__init__ / forward / load_weights` 都在 plugin scope 中运行 -- 避免 draft 初始化后把 target runtime 全局状态永久污染 - - -### 3. target wrapper 中补齐 config 绑定 - -文件: - -- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` - -修改点: - -- 在 `atom.prepare_model(...)` 返回后,立即抓取当前 `atom_config` -- 如果 `self.model.atom_config` 不存在,则显式补上 - -动机: - -- 避免 `setup_deepseek_for_sglang()` 回退去读全局 `get_current_atom_config()` -- 在 runtime scope 退出后出现: - - `AssertionError: Current atom config is not set` - - -### 4. draft runtime layer id 重编号 - -文件: - -- `ATOM/atom/plugin/sglang/models/deepseek_nextn_wrapper.py` - -新增逻辑: - -- `_retag_mtp_runtime_layer_ids(self.model)` - -动机: - -- `ATOM DeepSeekMTP` 的 checkpoint 语义使用全局层号: - - 如 `61`, `62`, ... -- 但 `SGLang` draft worker 给 draft KV cache 分配的 layer index 是本地层号: - - `0..num_nextn_layers-1` - -此前问题: - -- runtime `layer_id` 使用了 checkpoint/global layer id -- `token_to_kv_pool.set_kv_buffer(...)` 用这个 id 访问 draft KV buffer -- 出现: - - `IndexError: list index out of range` - -修法: - -- 保留 prefix / weight name 的全局层号语义 -- 仅把 runtime attention / radix attention / nested attn 的 `layer_id / layer_num` - 改为 draft-local layer index - - -### 5. 恢复 `config.json` 中 `num_hidden_layers` - -文件: - -- `ATOM/deepseek-ai/DeepSeek-R1-0528/config.json` - -改动: - -- 临时实验中曾将: - - `num_hidden_layers: 61 -> 16` -- 后来已恢复: - - `16 -> 61` - -结论: - -- 这个参数不能当成“简化实验”的随意开关 -- 它不仅影响 model topology,也影响: - - MTP layer weight naming - - runtime/global layer numbering - - KV cache / draft wrapper 语义 - - -### 6. plugin `RadixAttention` 中强制把 `k_scale / v_scale` 放到 CUDA - -文件: - -- `ATOM/atom/plugin/sglang/attention_backend/radix_attention.py` - -问题: - -- upstream `sglang` 的 `RadixAttention` 默认会把: - - `k_scale` - - `v_scale` - 建在 CPU 上 -- plugin wrapper 之前只在它们为 `None` 时才补 CUDA 参数 -- 但实际上这两个参数“不为 None,只是在 CPU 上” - -修法: - -- `None` 时创建 CUDA 参数 -- 已存在但不在 CUDA 时,也强制 `.to("cuda")` - -动机: - -- 避免 `mla_decode_fwd` 中把 CPU scale tensor 传进 `aiter` -- 触发: - - `aiter_tensor_t only supports CUDA tensors` - - -### 7. `sgl_attn_backend.py` 中补齐 speculative CUDAGraph metadata 分支 - -文件: - -- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` - -问题: - -- plugin 侧普通 runtime 的 `init_forward_metadata(...)` 已经支持: - - `decode_or_idle` - - `draft_extend` - - `target_verify` -- 但 `CUDAGraph capture/replay` 专用的 metadata 初始化函数只支持: - - `decode_or_idle` - -因此在 graph capture 阶段遇到: - -- `ForwardMode.TARGET_VERIFY` - -会直接报: - -- `ValueError: Invalid mode: forward_mode=` - -修法: - -- 在 `init_forward_metadata_capture_cuda_graph()` -- 和 `init_forward_metadata_replay_cuda_graph()` - -中补上: - -- `TARGET_VERIFY` -- `DRAFT_EXTEND` - -对应的 metadata 初始化分支 - -注意: - -- 这次补分支时又额外发现 plugin 自己的 `ForwardMetadata` 签名和 upstream 不同 -- plugin 版本额外有两个必填位置参数: - - `page_table` - - `kv_lens` -- 因此需要在新补的分支中显式补 `None, None` - - -### 8. CUDAGraph 相关深度 debug instrumentation - -文件: - -- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` -- `aiter/aiter/mla.py` - -目的: - -- 当前问题只在 `CUDAGraph + TARGET_VERIFY + MLA decode` 路径上复现 -- 普通推理路径可以正常工作 -- 因此需要抓到: - - 到底是哪一个 tensor 在进入 `aiter` kernel 前还停留在 CPU - -具体做法: - -1. 在 plugin backend `_call_mla_decode_fwd()` 里增加 tensor state dump -2. 后续发现这层不足以定位内部派生参数,于是继续下沉到: - - `aiter/aiter/mla.py::mla_decode_fwd` -3. 在真正调用: - - `aiter.mla_decode_stage1_asm_fwd(...)` - 前,检查: - - `q` - - `kv_buffer` - - `qo_indptr` - - `kv_indptr` - - `kv_indices` - - `kv_last_page_lens` - - `num_kv_splits_indptr` - - `work_meta_data` - - `work_indptr` - - `work_info_set` - - `q_scale` - - `kv_scale` - -最终定位到: - -- `kv_last_page_lens(device=cpu, dtype=torch.int32, shape=(48,), is_cuda=False)` - - -### 9. 增加断言区分 graph buffer 初始化与 metadata 绑定问题 - -文件: - -- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` - -新增 assert 位置: - -1. `init_cuda_graph_state()` 末尾 - - 断言: - - `self.cuda_graph_kv_last_page_len.is_cuda` -2. `init_forward_metadata_capture_cuda_graph()` / `replay` - - 在构造完 `ForwardMetadata` 后断言: - - `self.forward_metadata.kv_last_page_len is None or is_cuda` - -动机: - -- 区分问题到底是: - - graph state 初始化时就落到 CPU - - 还是后续 metadata 构造时又从别的来源拿了 CPU tensor - - -## 实验过程与关键观察 - -### 实验 1:确认 draft registry 是否已切到 ATOM wrapper - -动机: - -- 在真正调 runtime 之前,先确认 `ModelRegistry` 是否已经成功把 draft architecture 指到 plugin wrapper - -方法: - -- 通过 `SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.models` -- 检查: - - `ModelRegistry.resolve_model_cls(["DeepseekV3ForCausalLMNextN"])` - -结果: - -- 初次尝试时由于 import 链问题失败,draft 仍落回 upstream -- 修正 import 时机与 framework runtime scope 后,registry 成功解析到: - - `atom.plugin.sglang.models.deepseek_nextn_wrapper.DeepseekV3ForCausalLMNextN` - - -### 实验 2:确认 `NEXTN` 实际走的是哪个 speculative worker - -动机: - -- 判断当前问题是 `eagle_worker` 还是 `eagle_worker_v2` 路径特有 - -结果: - -- `NEXTN` 在 `SGLang` 中会先映射成 `EAGLE` -- 当前未开启 spec v2 / overlap -- 所以实际使用的是: - - `sglang/python/sglang/srt/speculative/eagle_worker.py` - -不是: - -- `eagle_worker_v2.py` - - -### 实验 3:内存初始化阶段 `mem_fraction_static` 与脚本不一致 - -现象: - -- 脚本设置: - - `--mem-fraction-static 0.9` -- runtime 报错里看到: - - `self.server_args.mem_fraction_static=0.765` - -结论: - -- 不是脚本没生效 -- 而是 `SGLang` 在 AMD + `aiter` + 长上下文模式下,会再乘一层 `0.85` - -相关逻辑: - -- 若: - - `attention_backend == "aiter"` - - 且 `context_len > 8192` -- 则: - - `mem_fraction_static *= 0.85` - - -### 实验 4:MoE 串台问题 - -现象: - -- 接上 draft wrapper 后,target 路径的 `MoE` forward 开始报: - - `KeyError: 'model.layers.3.mlp.experts'` - -结论: - -- 不是 MTP wrapper 直接把 MoE 改坏了 -- 而是: - - target / draft 共用全局 `current_atom_config` -- draft 初始化把全局配置切成了 draft config -- target 的 MoE forward 再去读全局配置时,读到了错误的 `static_forward_context` - -修法: - -- 在 plugin 层引入 `plugin_runtime_scope(...)` -- 所有 plugin wrapper 的 init / forward / load 都显式切回自己的 runtime context - - -### 实验 5:MTP runtime layer id 越界 - -现象: - -- 将 `config.json` 中 `num_hidden_layers` 改成 `16` 后,出现: - - `IndexError: list index out of range` - -分析后确认: - -- 根因不是“改成 16”本身 -- 而是 runtime `layer_id` 错误地用了 checkpoint/global layer number -- draft worker 的 KV cache 只按 draft-local 层数分配 - -结论: - -- 运行时 `layer_id` 应该是: - - `0, 1, 2, ...` -- 而不是: - - `16`, `61`, `62`, ... - - -### 实验 6:CUDAGraph `TARGET_VERIFY` 分支缺失 - -现象: - -- 开启 cuda graph 初始化时,直接报: - - `ValueError: Invalid mode: forward_mode=` - -结论: - -- plugin 侧 CUDAGraph metadata 初始化漏了 `TARGET_VERIFY` / `DRAFT_EXTEND` -- 这是 plugin 侧缺口,不是 upstream `SGLang` 自身不支持 - - -### 实验 7:`aiter_tensor_t only supports CUDA tensors` - -现象: - -- 在 `CUDAGraph + TARGET_VERIFY + MLA decode` 路径下, - `mla_decode_stage1_asm_fwd(...)` 报: - - `aiter_tensor_t only supports CUDA tensors` - -初步怀疑: - -- 可能是 plugin wrapper 的 `k_scale / v_scale` 还在 CPU - -部分修复: - -- 已在 plugin `RadixAttention` 中把 scale tensor 统一搬到 CUDA - -进一步深挖: - -- 下沉到 `aiter/aiter/mla.py` -- 发现真正触发断言的 tensor 为: - - `kv_last_page_lens(device=cpu, ...)` - - -### 实验 8:`kv_last_page_lens` 为 CPU 的进一步判断 - -关键观察: - -- plugin `ATOMAttnBackendForSgl.init_cuda_graph_state()` 中, - `self.cuda_graph_kv_last_page_len` 是按 `device=self.device` 创建的 -- 因此不太像是“子类 override 根本没生效” - -当前更强的判断是: - -- `forward_batch.attn_backend` 大概率仍然是 `ATOMAttnBackendForSgl` -- 但后续在 graph metadata 绑定 / 组装阶段,`forward_metadata.kv_last_page_len` - 被绑定成了 CPU tensor -- 也不排除: - - 某条复用父类 `forward_decode` 的路径里使用了来自父类默认初始化的 CPU graph buffer - -当前状态: - -- 已加断言区分: - - graph state 初始化阶段 - - 与 metadata 绑定阶段 -- 但在本次会话结束时,尚未拿到最终触发哪一个断言的最新日志 - - -## 当前对整体架构的理解 - -### 1. target 与 draft 在 plugin 模式下的实际形态 - -当前链路中: - -- target model: - - 由 `base_model_wrapper.py` 暴露为 `DeepseekV3ForCausalLM` - - 内部仍是 `ATOM DeepseekV3ForCausalLM` -- draft model: - - 由 `deepseek_nextn_wrapper.py` 暴露为 `DeepseekV3ForCausalLMNextN` - - 内部被绑定到 `ATOM DeepSeekMTP` - - -### 2. SGLang 与 ATOM 在 MTP 抽象上的差异 - -可以这样概括: - -- `SGLang` - - 把 MTP / NextN 视为一个独立的 runtime model - - draft worker 会单独初始化这个 model -- `ATOM` - - 把 MTP 实现成一个独立 draft core / module - - 需要由 speculative runtime 或 plugin wrapper 再包一层 - -因此: - -- `SGLang` 的差异在“接口层” -- `ATOM` 的差异在“实现层” - - -### 3. plugin 当前真正做的事情 - -当前 plugin 并不是直接修改 upstream `sglang` 的 DeepSeek MTP 逻辑,而是在三个层面做了替换: - -1. registry 层: - - 用 `EntryClass` 覆盖 `DeepseekV3ForCausalLMNextN` -2. wrapper 层: - - 用 SGLang 兼容壳把 `DeepSeekMTP` 暴露成 draft model -3. runtime 层: - - 补齐: - - embed/head sharing - - speculative hidden_states 输入 - - spec_decode 权重加载 - - runtime layer id 重编号 - - plugin runtime scope - - -## 当前仍存在的问题 - -截至本次记录结束,仍然有以下未完全解决的问题: - -- `CUDAGraph + TARGET_VERIFY + MLA decode` 路径仍存在 graph-only bug -- 当前最具体的线索是: - - `kv_last_page_lens` 在进入 `aiter.mla_decode_stage1_asm_fwd` 前为 CPU tensor -- 尚未最终确认这个 CPU tensor 是: - - 来自 graph state 初始化未正确走 plugin override - - 还是 metadata 构造过程中被其它路径重新绑定 -- 当前还没有完成的验证是: - - 让新增 assert 真正触发并给出第一现场 - - -## 本次新增 / 修改文件清单 - -### 新增 - -- `ATOM/atom/plugin/sglang/models/deepseek_nextn_wrapper.py` -- `ATOM/work_log/MTP/MTP-2026-04-09.md` - -### 修改 - -- `ATOM/atom/plugin/sglang/models/base_model_wrapper.py` -- `ATOM/atom/plugin/sglang/attention_backend/radix_attention.py` -- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` -- `aiter/aiter/mla.py` -- `ATOM/deepseek-ai/DeepSeek-R1-0528/config.json` - - -## 对后续工作的建议 - -### 1. 先把 CUDAGraph 的 `kv_last_page_len` 问题钉死 - -建议下一步只做一件事: - -- 继续跑一次带最新 assert 的启动 -- 看究竟是: - - `init_cuda_graph_state()` 断言触发 - - 还是 `ForwardMetadata` 组装后的断言触发 - -这样可以把问题准确收敛到: - -- graph buffer 初始化层 -- 或 metadata 绑定层 - - -### 2. 不要再直接改 `config.json::num_hidden_layers` - -如果需要降低实验复杂度,更建议: - -- 调小: - - `--context-length` - - `--max-running-requests` - - `--chunked-prefill-size` - - `--max-total-tokens` - -而不是直接改: - -- `num_hidden_layers` - - -### 3. 区分“普通 speculative 能跑”和“CUDAGraph 也能跑” - -当前已经能说明: - -- 普通 speculative 路径和 graph 路径不是同一个问题集合 -- graph 路径会更早暴露: - - metadata device 问题 - - verify-only 分支缺失 - - graph persistent buffer 设备不一致 - -因此后续调试时建议始终把问题分成两类: - -- 非 graph runtime bug -- CUDAGraph 专有 bug - - -## 一句话总结 - -本次工作的核心进展不是“DeepSeek MTP 已完全跑通”,而是: - -- 已经把 draft architecture 从 upstream `DeepseekV3ForCausalLMNextN` - 成功接到了 `ATOM DeepSeekMTP` -- 已经把 plugin runtime scope、MTP runtime layer id、speculative metadata 等关键适配层基本搭起来 -- 当前剩余阻塞点主要集中在 `CUDAGraph + TARGET_VERIFY + MLA decode` 的 graph-only 设备与 metadata 一致性问题 - diff --git a/work_log/MTP/MTP-2026-04-10.md b/work_log/MTP/MTP-2026-04-10.md deleted file mode 100644 index 466bf8562..000000000 --- a/work_log/MTP/MTP-2026-04-10.md +++ /dev/null @@ -1,801 +0,0 @@ -# 2026-04-10 ATOM Plugin 模式下 DeepSeek MTP 的 CUDAGraph 调试、修复与知识沉淀 - -## 目标 - -今天这轮工作的目标主要有两条: - -1. 继续推进 `ATOM plugin + SGLang + DeepSeek MTP` 路径下的 `CUDAGraph` 调试,重点把昨天留下的 graph-only 问题继续收敛。 -2. 把今天在调试过程中澄清的一些关键背景知识整理成系统化文档,方便后续继续做 MTP / speculative / CUDAGraph 相关工作时复盘和学习。 - -今天聚焦的问题主要有两个: - -- 为什么此前在 `CUDAGraph + TARGET_VERIFY + MLA decode` 路径上,`kv_last_page_lens` 会落在 CPU 上。 -- 为什么后续修掉 CPU tensor 问题后,又在 `draft_extend replay` 路径上遇到 `bs=1` 但 `seq_lens.shape[0]=48` 的 shape mismatch。 - - -## 本次结论速览 - -今天最关键的结论有四个: - -1. `kv_last_page_lens(device=cpu)` 的根因不是简单的“plugin backend override 没生效”,而是: - - `EAGLE draft` 路径里使用的 `AiterMultiStepDraftBackend` - - 在其内部直接实例化了 upstream `AiterAttnBackend` - - 绕过了 plugin 通过 `"aiter"` 名字注册的 `ATOMAttnBackendForSgl` - - 结果 upstream 默认的 CPU `cuda_graph_kv_last_page_len` 泄漏进 graph 路径 - -2. `draft_extend replay` 的 `bs=1` 但 `seq_lens.shape[0]=48` 报错,本质上是: - - replay 选中的 graph bucket 是 `1` - - 但 `draft_extend cuda graph runner` 把整块静态 buffer 传给了 backend - - plugin backend 的 `draft_extend replay` 分支最初没有像 upstream 一样先做 `seq_lens[:bs]` 和 `accept_length[:bs]` 的规整 - -3. 修这个 `draft_extend replay` 问题,更合理的做法不是在函数入口打一层粗粒度的统一切片补丁,而是: - - 让 plugin 的 `init_forward_metadata_replay_cuda_graph()` 中 - - `forward_mode.is_draft_extend()` 这个分支 - - 在语义上与 upstream `AiterAttnBackend` 的对应分支对齐 - -4. 今天补充出的两篇专题文档,把下面几类知识基本梳理清楚了: - - `CUDAGraph` 在 `SGLang` 中到底固定什么 - - 为什么 `decode` 更适合做 graph - - 为什么普通 `prefill / extend` 更难 graph 化 - - `raw_bs / bs / num_tokens` - - `qo_indptr / kv_indptr / kv_indices` - - 特别是 `kv_indices` 的物理含义、shape、和 `req_to_token` 的对应关系 - - -## 今天开始时的上下文 - -昨天的工作已经完成了几件重要事情: - -- plugin draft wrapper 已接到 `ATOM DeepSeekMTP` -- `SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.models` 已经生效 -- `TARGET_VERIFY / DRAFT_EXTEND` 的 graph metadata 分支已经在 plugin backend 中补齐 -- `k_scale / v_scale` 的 CPU tensor 问题已经修过 - -但昨天留下的核心阻塞点是: - -- `CUDAGraph + TARGET_VERIFY + MLA decode` 路径里 -- `aiter.mla_decode_stage1_asm_fwd(...)` 之前仍然有 metadata tensor 落在 CPU -- 明确抓到的现场是: - - `kv_last_page_lens(device=cpu, dtype=torch.int32, shape=(48,), is_cuda=False)` - -昨天的一个判断是: - -- `ATOMAttnBackendForSgl.init_cuda_graph_state()` override 本身大概率是成功的 -- 更可能的问题在 graph metadata 组装或者某条 graph-only 调用链里 - -今天的工作就是在这个基础上继续收敛。 - - -## 背景知识补充 - -### 1. 为什么今天的两个 bug 都是 graph-only bug - -今天碰到的两个 bug: - -- `kv_last_page_lens` 落在 CPU -- `draft_extend replay` 的 `bs=1` / `seq_lens.shape[0]=48` - -都有一个共同特点: - -- 普通 eager runtime 或普通 speculative 路径不一定暴露 -- 但在 `cuda graph capture / replay` 路径上会被迅速放大 - -原因在于 graph 路径有两层额外约束: - -1. 需要静态持久 buffer -2. 需要严格区分: - - 当前 bucket 的有效视图 - - 底层静态 backing buffer - -graph-only bug 往往不是“attention 算法错了”,而是: - -- backend 选型不对 -- graph state 的 persistent buffer 设备不对 -- replay 时 view 与 backing buffer 混淆 - -### 2. `raw_bs`、`bs`、`num_tokens` - -今天反复用到的三个概念: - -- `raw_bs` - - 真实 batch 中当前有多少个 request -- `bs` - - 这次 replay 选中的 graph bucket 大小 -- `num_tokens` - - 本轮真正参与 forward 的 token 数 - -这三者在 graph 路径中不一定相等。 - -特别是在 speculative 路径中: - -- `draft decode` - - `num_tokens = bs * topk` -- `draft extend` - - `num_tokens = bs * (speculative_num_steps + 1)` -- 普通 prefill - - 常常是 `sum(extend_seq_lens)` - -### 3. `kv_indices` 的意义 - -今天在整理文档时,又把 `kv_indices` 这类字段重新梳理了一遍。 - -一句话: - -- `kv_indices` 不是逻辑 token 下标 -- 它是当前这轮 attention 真正要访问的 **physical KV slot id 列表** - -它和下面几个量一起理解最清楚: - -- `req_pool_indices` -- `req_to_token` -- `seq_lens` -- `kv_indptr` - -也就是: - -- 先按 `req_pool_indices` 找到 request 在 `req_to_token` 中的那一行 -- 再按 `seq_lens[i]` 取出这行前面的有效 token -> physical slot 映射 -- 最后拼成一个一维 flatten 数组 - -这个理解对于看懂 `create_flashinfer_kv_indices_triton(...)` 非常重要。 - - -## 问题一:`kv_last_page_lens` 在 CPU 上 - -### 现象 - -从 `log.serve.log` 和之前在 `aiter/mla.py` 加的 debug 可见: - -- 进入 `mla_decode_stage1_asm_fwd` 前 -- 只有 `kv_last_page_lens` 仍然是 CPU tensor -- 其他关键张量如: - - `q` - - `kv_buffer` - - `qo_indptr` - - `kv_indptr` - - `kv_indices` - - `work_metadata` - - `q_scale` - - `kv_scale` - 都已经是 CUDA tensor - -### 初始猜想 - -最初怀疑的方向有两个: - -1. plugin backend 的 `init_cuda_graph_state()` override 根本没生效 -2. override 生效了,但 graph metadata 组装阶段又把 `forward_metadata.kv_last_page_len` 绑回了别处的 CPU tensor - -### 继续追链路后的发现 - -今天顺着 `EAGLE draft` 的 graph 初始化链继续往下看,发现真正关键的地方是: - -- `EAGLEWorker.init_attention_backend()` - - 会通过 `DraftBackendFactory.create_decode_backend()` 构造 draft 的 decode backend -- 当 backend 选择是 `"aiter"` 时 - - `DraftBackendFactory._create_aiter_decode_backend()` - - 会直接实例化 `AiterMultiStepDraftBackend` -- 而 `AiterMultiStepDraftBackend.__init__()` 内部又会直接 new: - - `AiterAttnBackend(...)` - -问题就在这里: - -- 这条 direct construction 没走 attention registry -- 所以 plugin 在 `"aiter"` 名字上注册的 `ATOMAttnBackendForSgl` - 并不会自动生效 - -换句话说: - -- target 路径可能走的是 plugin backend -- 但 draft multi-step graph 路径内部某些 step backend 仍然是 upstream `AiterAttnBackend` - -### 为什么这会导致 CPU `kv_last_page_lens` - -上游 `AiterAttnBackend.init_cuda_graph_state()` 里有: - -- `self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)` - -也就是: - -- 默认建在 CPU 上 - -而 plugin 版已经修成: - -- `torch.ones(max_bs, dtype=torch.int, device=self.device)` - -因此,一旦 `AiterMultiStepDraftBackend` 内部 step backend 实际还是 upstream 版本: - -- graph state 里的 `cuda_graph_kv_last_page_len` - 就是 CPU tensor - -这就是之前 `mla_decode_fwd` 抓到: - -- `kv_last_page_lens(device=cpu, ...)` - -的根因。 - -### 修法 - -修法放在 plugin 层,不改 upstream `sglang`: - -- 文件: - - `ATOM/atom/plugin/register.py` - -做法: - -- 继续注册 `"aiter" -> ATOMAttnBackendForSgl` -- 同时 monkeypatch: - - `sglang.srt.layers.attention.aiter_backend.AiterAttnBackend` - - 让 direct import / direct construction 也落到 plugin backend - -核心代码: - -- import upstream `aiter_backend` 模块 -- `sglang_aiter_backend.AiterAttnBackend = ATOMAttnBackendForSgl` - -### 为什么这个修法合理 - -因为这个问题的本质不是: - -- metadata 算法错 - -而是: - -- graph 路径内部实际跑的 backend 实例不对 - -所以应该在 plugin 注入层统一修 backend 选型,而不是在更下游继续 patch 每个 graph state 字段。 - -### 验证 - -补了对应单测: - -- 文件: - - `ATOM/tests/plugin/test_sglang_register.py` - -新增检查: - -- 除了验证 `"aiter"` 名字的 registry 绑定 -- 还显式验证: - - `sglang.srt.layers.attention.aiter_backend.AiterAttnBackend` - - 也被换成了 plugin backend - -执行: - -- `pytest -q ATOM/tests/plugin/test_sglang_register.py` - -结果: - -- `9 passed` - - -## 问题二:`draft_extend replay` 中 `bs=1`,但 `seq_lens.shape[0]=48` - -### 现象 - -在下一轮服务运行中,又遇到一个新的 graph-only 错误: - -- `draft_extend replay` -- `init_forward_metadata_replay_cuda_graph()` -- `kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)` -- 报: - - `Target sizes: [1]` - - `Tensor sizes: [48]` - -也就是: - -- 当前 replay bucket 是 `bs = 1` -- 但传进来的 `seq_lens` 仍然是长度 `48` 的静态 backing buffer - -### 为什么会这样 - -`draft_extend cuda graph runner` 在 replay 时确实是这样传参的: - -- 会先把真实 batch 数据 copy 到静态 buffer 的前缀 -- 但调用 backend 时,传的是整块 `buffers.seq_lens` / `buffers.req_pool_indices` -- 而不是 `[:bs]` 视图 - -这意味着: - -- caller 传的是 backing buffer -- callee 却按“当前 bucket 的有效 view”去理解 - -于是就会出现: - -- 左边 slice 长度是 `1` -- 右边 `cumsum(seq_lens)` 长度是 `48` - -### 一开始的临时修法 - -我最初为了快速兜住问题,尝试过: - -- 在 plugin backend 的 `init_forward_metadata_replay_cuda_graph()` 入口 -- 对 `req_pool_indices / seq_lens / seq_lens_cpu` - 统一做 `[:bs]` - -这个做法能修掉 shape mismatch,但后来用户指出了一个更重要的问题: - -- 这里不应该只做粗粒度防御性 patch -- 更应该看 upstream 同分支是怎么处理的 - -这是对的。 - -### 继续对 upstream 后的发现 - -upstream `AiterAttnBackend.init_forward_metadata_replay_cuda_graph()` 的 -`draft_extend` 分支里,语义并不只是: - -- `seq_lens = seq_lens[:bs]` - -还包括: - -- `accept_lens = spec_info.accept_length[:bs]` -- `qo_indptr[1:] = cumsum(accept_lens)` - -也就是说,upstream 在 replay 阶段的 query 分段不是简单固定步长, -而是和 `accept_length` 绑定。 - -### 最终修法 - -因此最终采用的不是“函数入口统一切片”的粗 patch,而是: - -- 让 plugin 的 - - `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` - 中 - - `init_forward_metadata_replay_cuda_graph()` - 的 `forward_mode.is_draft_extend()` 分支 - 在语义上对齐 upstream - -具体做了两件事: - -1. `seq_lens = seq_lens[:bs]` -2. `accept_lens = spec_info.accept_length[:bs]` -3. `qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)` - -并保留 plugin 自己的: - -- `ForwardMetadata` 字段布局 -- MLA persistent kernel metadata 生成 -- `kv_last_page_len.is_cuda` 断言 - -### 一个小插曲:`qo_indptr[0] = 0` - -在第一次改这个分支时,我曾额外加过: - -- `qo_indptr[0] = 0` - -后来和 upstream 对比后又删掉了,原因是: - -- upstream 没有这行 -- `self.qo_indptr` 在父类中本来就是 `torch.zeros(...)` 初始化 -- 这类 CSR/indptr buffer 在很多地方都只写 `1:` - -因此: - -- 这行虽然防御性上没错 -- 但既然目标是与 upstream 严格语义对齐,就不应额外保留 - - -## 关于“为什么 graph 路径里 metadata 可以每轮重建” - -今天还专门澄清了一个很重要的问题: - -- graph replay 里 metadata 明明每轮都会重新构造 -- 为什么 decode / draft_extend 还能做 graph -- 而 prefill 更难 - -### 结论 - -不是“metadata 变了就不能 graph”,而是要区分: - -1. metadata 的**内容**变 -2. metadata 是否会进一步影响: - - graph 内部分支 - - 中间 tensor shape - - workspace shape - - kernel launch 形态 - -对于 decode / draft_extend graph 路径: - -- metadata 的构造发生在 graph 外 -- graph 内看到的是固定地址、固定 shape 的 persistent buffer -- replay 只是改这些 buffer 的内容 - -因此 graph 仍然可复用。 - -而普通 prefill: - -- `total_tokens` -- `max_q_len` -- `max_kv_len` -- `qo_indptr` -- `kv_indptr` -- workspace 大小 -都可能跟 batch 一起变化 - -这时 metadata 已经不只是“参数”,而更像“问题几何结构”的一部分。 - -所以: - -- 普通 prefill 仍然难 graph -- 即使你不考虑不同 kernel path - - -## 今天补出的三篇背景文档 - -今天除了修 bug,还补了三份面向复盘的文档。 - -### 1. `2026-04-10-sglang-cudagraph-prefill-decode-metadata-guide.md` - -主题: - -- `CUDAGraph` 在 `SGLang` 中固定的到底是什么 -- 为什么 `decode` 更适合做 graph -- 为什么普通 `prefill / extend` 难 graph -- `raw_bs / bs / num_tokens` -- `SGLang` 在 decode 阶段怎样做 capture / replay -- `ForwardMetadata` 在 graph capture / replay 中扮演什么角色 - -适合什么时候看: - -- 想从整体架构层面理解 `graph + attention metadata` - -### 2. `2026-04-10-sglang-simple-prefill-cudagraph-metadata-guide.md` - -主题: - -- 只收窄到“最简单 prefill” -- 不考虑不同 kernel path -- 不考虑 prefix cache -- 不考虑 speculative -- 解释: - - 为什么即使是最简单 prefill 也难 graph - - `qo_indptr / kv_indptr / kv_indices / max_q_len / max_kv_len` - 的 shape 与物理意义 - - 给多个可手算例子 - -特别补充了: - -- `kv_indices` 的详细解释 -- `req_to_token` -- `req_pool_indices` -- 以及 `qo_indptr + kv_indptr + kv_indices` 联合看时的完整例子 - -适合什么时候看: - -- 想快速回忆 metadata 的 shape 和物理含义 - -### 3. `MTP-2026-04-10.md` - -也就是本文件,作为今天的完整工作日报。 - - -## 今天的代码改动清单 - -### 1. `ATOM/atom/plugin/register.py` - -改动: - -- 除了继续通过 registry 把 `"aiter"` 绑定到 `ATOMAttnBackendForSgl` -- 还显式把 upstream 模块符号: - - `sglang.srt.layers.attention.aiter_backend.AiterAttnBackend` - 重新绑定到 plugin backend - -目的: - -- 修复 `AiterMultiStepDraftBackend` 内部 direct construction 绕过 registry 的问题 - -### 2. `ATOM/tests/plugin/test_sglang_register.py` - -改动: - -- 扩展测试覆盖 -- 不仅验证 registry name 是 `"aiter"` -- 也验证 `AiterAttnBackend` 模块符号被替换成了 plugin backend - -目的: - -- 防止以后又出现“registry 绑了,但 direct construction 仍然绕过 plugin”的回归 - -### 3. `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` - -改动: - -- 继续保留 graph metadata 断言 -- 把 `init_forward_metadata_replay_cuda_graph()` 中 - `forward_mode.is_draft_extend()` 分支 - 调整为更接近 upstream 的语义: - - `seq_lens = seq_lens[:bs]` - - `accept_lens = spec_info.accept_length[:bs]` - - `qo_indptr[1:] = cumsum(accept_lens)` - - `kv_indptr[1:] = cumsum(seq_lens)` - -目的: - -- 修复 `draft_extend replay` 中 - - `bs=1` - - 但 `seq_lens.shape[0]=48` - 的 graph-only mismatch - -### 4. 新增文档 - -- `ATOM/work_log/MTP/2026-04-10-sglang-cudagraph-prefill-decode-metadata-guide.md` -- `ATOM/work_log/MTP/2026-04-10-sglang-simple-prefill-cudagraph-metadata-guide.md` -- `ATOM/work_log/MTP/MTP-2026-04-10.md` - - -## 今天做过但没有保留的尝试 - -今天有一个短暂尝试后来撤回了: - -- 我曾经直接修改过 upstream: - - `sglang/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py` -- 试图在 caller 侧把传给 backend 的 `buffers.seq_lens` / `buffers.req_pool_indices` - 改成 `[:bs]` - -这个改法从技术上能工作,但不符合这次工作的原则: - -- 用户要求不要改 upstream `sglang` -- 这类修法也会把 plugin 自己对 graph 语义的偏差掩盖掉 - -因此后来把它撤回了,最终修法完全落在 plugin 内部。 - - -## 实验过程与结果 - -### 实验 1:确认 `kv_last_page_lens` 问题是不是 graph state 初始化失败 - -动机: - -- 区分问题到底是: - - plugin backend 的 graph state 初始化失败 - - 还是后续调度/metadata 链路里又掉回 upstream backend - -方法: - -- 检查 plugin backend 中 `init_cuda_graph_state()` 的实现 -- 对照 upstream `AiterAttnBackend` -- 追 `EAGLEWorker -> DraftBackendFactory -> AiterMultiStepDraftBackend` - 的实例化链路 - -结果: - -- 发现不是简单的 override 失败 -- 而是 `AiterMultiStepDraftBackend` 内部 direct new `AiterAttnBackend` - 绕过了 plugin registry - -结论: - -- 这是 backend 注入点不完整导致的 graph-only bug - -### 实验 2:确认修 backend 注入后,下一处报错落在哪里 - -动机: - -- 修完 CPU tensor 问题后,需要继续看下一层 graph-only 问题 - -方法: - -- 根据新的 `log.serve.log` 栈追 `draft_extend replay` - -结果: - -- 抓到新的错误: - - `kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)` - - `Target sizes: [1]` - - `Tensor sizes: [48]` - -结论: - -- graph runner 的 backing buffer 和当前 bucket view 混淆了 - -### 实验 3:先做临时入口兜底,再回看 upstream 语义 - -动机: - -- 先快速验证 shape mismatch 是否确实是 view 问题 - -方法: - -- 在 plugin `init_forward_metadata_replay_cuda_graph()` 入口加统一 `[:bs]` - 规整 - -结果: - -- 从逻辑上能兜住这类 mismatch - -但随后进一步对照 upstream 发现: - -- 这个问题不只是切片 -- 还牵涉到 `draft_extend replay` 下 `qo_indptr` 的语义应该由 - `accept_length` 决定 - -结论: - -- 函数入口打补丁只是临时兜底 -- 更好的修法是按 upstream 分支语义对齐 - -### 实验 4:对齐 plugin `draft_extend replay` 分支到 upstream - -动机: - -- 让 plugin 分支和 upstream 的 graph replay 语义一致 - -方法: - -- 把 plugin 的 `draft_extend replay` 分支改为: - - `seq_lens = seq_lens[:bs]` - - `accept_lens = spec_info.accept_length[:bs]` - - `qo_indptr[1:] = cumsum(accept_lens)` - -结果: - -- 代码语义与 upstream 一致性明显更强 -- 且不需要修改 upstream `sglang` - -### 实验 5:运行服务并看最新日志 - -动机: - -- 看修复后服务是否恢复到正常请求处理状态 - -结果: - -- `log.serve.log` 最新尾部可见: - - Prefill batch 仍显示: - - `cuda graph: False` - - Decode batch 显示: - - `cuda graph: True` - - 多个请求成功返回 `200 OK` - -这至少说明: - -- 当前服务已经重新进入了正常请求处理状态 -- prefill 依旧不走 graph,这与当前系统设计一致 -- decode graph 已经在工作 - -需要注意: - -- 今天没有做系统化 benchmark -- 这里只能说明日志上服务在继续跑,不能说明所有 corner case 都完全验证完成 - - -## 今天形成的理解:为什么 decode graph 可以,而 prefill 更难 - -今天围绕用户问题,又把这件事重新总结了一遍。 - -### 1. decode graph 为什么更自然 - -因为 decode 常常满足: - -- 每个 request 每轮 query token 数固定 -- `num_tokens_per_bs` 固定 -- `max_q_len` 常常固定为 `1` -- metadata 更多是在固定地址上的输入数据 - -因此: - -- 可以通过 `bs bucket + 静态 buffer + metadata 重建` - 来复用 graph - -### 2. prefill graph 为什么更难 - -即使不考虑 prefix cache 和不同 kernel path,最简单 prefill 也会遇到: - -- `total_tokens = sum(extend_seq_lens)` 会变 -- `qo_indptr` 会变 -- `max_q_len` 会变 -- `max_kv_len` 会变 -- 中间 tensor / workspace shape 也会变 - -所以它的 challenge 不只是: - -- metadata 值变了 - -而是: - -- metadata 连同问题几何结构一起变了 - -这会让: - -- graph 内部张量 shape -- workspace 形状 -- 有时连 kernel launch 计划 -都跟着变化 - - -## 当前状态判断 - -截至今天结束,可以比较有把握地说: - -1. DeepSeek MTP 的 plugin draft 路径接线已经比昨天更稳固: - - direct construction 绕过 plugin backend 的坑已经堵住 - -2. `draft_extend replay` 的 graph metadata 语义也更接近 upstream: - - 不再是一个仅靠入口切片兜底的 patch - - 而是按 upstream 分支语义修正 - -3. 从最新日志看: - - prefill 继续按设计走 `cuda graph: False` - - decode 已经在 `cuda graph: True` - - 服务能正常响应请求 - -4. 但今天没有完成的事情仍然有: - - 没有做系统化压测或 benchmark - - 没有确认所有 speculative 相关 corner case 都已经覆盖 - - 没有进一步处理“普通 prefill 是否值得 graph 化”的工程设计 - - -## 本次执行过的验证命令 - -今天执行过的本地验证主要包括: - -- `pytest -q ATOM/tests/plugin/test_sglang_register.py` - - 结果:通过 -- `python3 -m py_compile` - - 对修改过的 plugin 文件做语法检查 - - 结果:通过 - -另外: - -- 通过 `log.serve.log` 持续跟踪服务运行结果 -- 从日志尾部确认: - - `Prefill batch ... cuda graph: False` - - `Decode batch ... cuda graph: True` - - 多个 `/v1/completions` 返回 `200 OK` - - -## 对后续工作的建议 - -### 1. 继续观察真实服务日志 - -虽然今天的两个 graph-only bug 已经定位并修正,但建议继续跑一段时间,看是否还有新的: - -- speculative-only -- graph-only -- replay-only - -问题继续冒出。 - -### 2. 如果再出现 graph-only 问题,优先检查两类契约 - -今后的 graph 调试,优先看: - -1. backend 注入契约 - - 实际实例是不是 plugin backend -2. backing buffer / bucket view 契约 - - caller 传的是整块静态 buffer,还是当前 `[:bs]` view - -很多 graph-only bug 最后都归到这两类。 - -### 3. 如需进一步探索 prefill graph,可先限定一个最简单子场景 - -如果未来要继续研究: - -- “普通 prefill 是否也能做 graph” - -建议不要直接想做“所有 prefill graph”,而是先限定: - -- 无 prefix cache -- 固定 kernel path -- 小量 bucket - -然后再评估: - -- `total_tokens` -- `max_q_len` -- `max_kv_len` -- workspace - -能否通过 bucket 化或 pad/unpad 收敛。 - - -## 一句话总结 - -今天最核心的进展不是“DeepSeek MTP 的所有问题都已解决”,而是: - -- 把 `kv_last_page_lens` 掉到 CPU 的 graph-only 根因准确收敛到 - `AiterMultiStepDraftBackend` 绕过 plugin backend -- 把 `draft_extend replay` 的 `bs=1 / seq_lens=48` 问题从临时补丁, - 收敛为与 upstream 语义对齐的修法 -- 同时把 `CUDAGraph / decode / prefill / metadata / kv_indices` - 这一整套背景知识整理成了更适合后续复盘的文档体系