
Conversation

@freeliuzc (Collaborator) commented Dec 24, 2025

Motivation

  1. With multi-step MTP, mask_attn_offset is not rolled back, so MTP uses an incorrect attn_mask during inference.
  2. In PD-disaggregation mode, P (prefill) runs inference only once, so this case needs extra handling.
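A minimal sketch of the failure mode described above, assuming illustrative names (`run_round`, the integer offset model, and `num_model_steps = 2` are all hypothetical, not FastDeploy's actual code): each draft substep advances the attention-mask offset, and without a rollback between decode rounds the offset drifts ahead of the true sequence position.

```python
# Hypothetical sketch of the bug: in multi-step MTP each substep advances the
# per-request attention-mask offset; without a rollback between decode rounds
# the offset drifts, so the draft model attends with a wrong mask.

def run_round(mask_attn_offset: int, num_model_steps: int, rollback: bool) -> int:
    for _ in range(num_model_steps):
        mask_attn_offset += 1                    # each draft substep shifts the mask
    if rollback:
        mask_attn_offset -= num_model_steps - 1  # keep only the accepted step
    return mask_attn_offset

offset_buggy = offset_fixed = 0
for _ in range(3):  # three decode rounds with num_model_steps = 2
    offset_buggy = run_round(offset_buggy, 2, rollback=False)
    offset_fixed = run_round(offset_fixed, 2, rollback=True)

print(offset_buggy, offset_fixed)  # buggy drifts to 6, fixed stays at 3
```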

Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings December 24, 2025 05:47
@paddle-bot bot commented Dec 24, 2025

Thanks for your contribution!

Copilot AI (Contributor) left a comment

Pull request overview

This PR fixes the attention mask offset calculation for multi-step MTP (Multi-Token Prediction) in mixed and PD-split (Prefill-Decode split) modes of speculative decoding. The fix addresses incorrect mask rollback behavior when the draft model operates in these specific configurations.

Key Changes:

  • Modified mask_rollback calculation in PD-split mode to account for num_model_steps
  • Added mask_rollback parameter propagation through the CUDA kernel pipeline
  • Added debug logging for troubleshooting attn_mask_offsets and sequence lengths

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

Files and descriptions:

  • fastdeploy/spec_decode/mtp.py — Updated the mask_rollback calculation formula for PD-split mode, added mask_rollback to the kernel inputs, and included debug logging
  • custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu — Added the mask_rollback parameter to the kernel signature and implemented mask_rollback accumulation in the decode generation path; includes code-formatting improvements
  • custom_ops/gpu_ops/cpp_extensions.cc — Added the mask_rollback parameter to the DraftModelPreprocess function signature for consistency
Comments suppressed due to low confidence (1)

fastdeploy/spec_decode/mtp.py:883

  • These debug logging statements should be removed before merging to production. Debug logs at the info level in performance-critical code paths can significantly impact performance, especially when they execute on every substep iteration.

                # Initialize forward meta data
                self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
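On the logging point above: one common way to keep such diagnostics without paying their cost on every substep is to guard them behind the logger's level check. A generic sketch, assuming a hypothetical logger name and message (not FastDeploy's actual code):

```python
import logging

logger = logging.getLogger("spec_decode")

def log_substep_state(attn_mask_offsets, seq_lens_this_time):
    # isEnabledFor short-circuits before any message formatting happens,
    # so the hot path pays only one cheap level comparison when DEBUG is off.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "attn_mask_offsets=%s seq_lens_this_time=%s",
            attn_mask_offsets,
            seq_lens_this_time,
        )
```

Using %s-style lazy formatting (rather than an f-string) also defers stringifying large tensors until a handler actually emits the record.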

Comment on lines 565 to 572
# NOTE(liuzichang):
# extra 1: P-D split needs to roll back one step.
# -(num_model_steps - 1): draft_model_preprocess rolls back
# (num_model_steps - 1) in each step, but in P-D splitwise mode
# P generates only one token, so that amount must be subtracted.
self.model_inputs["mask_rollback"][idx : idx + 1] = 1 - (
    self.speculative_config.num_model_steps - 1
)
Copilot AI Dec 24, 2025

The calculation logic for mask_rollback is complex and would benefit from clearer documentation. The comment mentions two points but the explanation is somewhat unclear. Consider adding:

  1. A more detailed explanation of why we subtract (num_model_steps - 1)
  2. What the relationship is between P-D splitwise mode and num_model_steps
  3. An example calculation to illustrate the expected values

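The example calculation the review asks for can be sketched as follows, using the formula from the quoted snippet (the helper name `pd_split_mask_rollback` is hypothetical, introduced only for illustration):

```python
# Hypothetical helper illustrating the quoted formula: 1 - (num_model_steps - 1).
# The +1 covers the extra rollback step P-D split needs; the -(num_model_steps - 1)
# cancels the rollback draft_model_preprocess already applies each step, which the
# prefill stage (one token only) never accumulated.

def pd_split_mask_rollback(num_model_steps: int) -> int:
    return 1 - (num_model_steps - 1)

for steps in (1, 2, 3):
    print(f"num_model_steps={steps} -> mask_rollback={pd_split_mask_rollback(steps)}")
# num_model_steps=1 -> mask_rollback=1
# num_model_steps=2 -> mask_rollback=0
# num_model_steps=3 -> mask_rollback=-1
```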
Copilot uses AI. Check for mistakes.
Comment on lines 565 to 572

Copilot AI Dec 24, 2025

The PR description is incomplete. According to the PR template and custom guidelines, the description should include:

  1. Motivation: A clear explanation of what problem this PR solves
  2. Modifications: Detailed description of the changes made
  3. Usage or Command: How to test or use these changes
  4. Accuracy Tests: Results demonstrating the fix works correctly

Currently, all these sections are empty. Please provide:

  • Why the attn_mask_offset calculation was incorrect
  • What scenarios were failing before this fix
  • Test results showing the issue is resolved

Copilot generated this review using guidance from repository custom instructions.
Comment on lines +297 to +326
process_splitwise_prefill<BlockSize,
TRUNCATE_FIRST_TOKEN,
KVCACHE_SCHEDULER_V1>
<<<1, BlockSize, 0, stream>>>(draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len);
Copilot AI Dec 24, 2025

The process_splitwise_prefill kernel doesn't receive or use the mask_rollback parameter, while the draft_model_preprocess_kernel does. This asymmetry could be problematic if mask_rollback adjustments are needed in splitwise_prefill mode. Please verify whether:

  1. The mask_rollback logic is intentionally not needed for splitwise_prefill scenarios
  2. If it is needed, the parameter should be added to this kernel as well

If it's intentional that splitwise_prefill doesn't need mask_rollback, consider adding a comment explaining why.

@codecov-commenter commented Dec 24, 2025

Codecov Report

❌ Patch coverage is 50.00000% with 1 line in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@8beb015). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/spec_decode/mtp.py 50.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #5738   +/-   ##
==========================================
  Coverage           ?   65.91%           
==========================================
  Files              ?      330           
  Lines              ?    41819           
  Branches           ?     6406           
==========================================
  Hits               ?    27567           
  Misses             ?    12210           
  Partials           ?     2042           
Flag Coverage Δ
GPU 65.91% <50.00%> (?)

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@yuanlehome yuanlehome merged commit 9018ccf into PaddlePaddle:develop Dec 25, 2025
23 of 30 checks passed
freeliuzc added a commit to freeliuzc/FastDeploy that referenced this pull request Dec 26, 2025
…ed and PD-split modes (PaddlePaddle#5738)

* fix attn_mask_offset in mtp with multi-step and pd-split-mode

* fix xpu operater register

* update pmtp multi-step mtp strategy in d-split -mode

* add note

* fix xpu register
freeliuzc added a commit to freeliuzc/FastDeploy that referenced this pull request Dec 26, 2025
…ed and PD-split modes (PaddlePaddle#5738)
EmmonsCurse pushed a commit that referenced this pull request Dec 26, 2025
…ed and PD-split modes (#5738) (#5792)
Deleter-D pushed a commit to Deleter-D/FastDeploy that referenced this pull request Dec 29, 2025
…ed and PD-split modes (PaddlePaddle#5738)
Deleter-D pushed a commit to Deleter-D/FastDeploy that referenced this pull request Dec 29, 2025
…ed and PD-split modes (PaddlePaddle#5738)
Deleter-D added a commit to Deleter-D/FastDeploy that referenced this pull request Dec 29, 2025
ckl117 pushed a commit to fxyfxy777/FastDeploy that referenced this pull request Dec 29, 2025
…ed and PD-split modes (PaddlePaddle#5738)
yuanlehome pushed a commit that referenced this pull request Dec 30, 2025
* [Speculative Decoding] Fix attn_mask_offset for multi-step MTP in mixed and PD-split modes (#5738)

* fix attn_mask_offset in mtp with multi-step and pd-split-mode

* fix xpu operater register

* update pmtp multi-step mtp strategy in d-split -mode

* add note

* fix xpu register

* fix entropy bugs

* Revert "[Speculative Decoding] Fix attn_mask_offset for multi-step MTP in mixed and PD-split modes (#5738)"

This reverts commit ba0d35a.

* fix ut

* fix

---------

Co-authored-by: freeliuzc <lzc842650834@gmail.com>
yuanlehome added a commit that referenced this pull request Jan 5, 2026
…ed and PD-split modes (#5738) (#5793)

Co-authored-by: Yuanle Liu <yuanlehome@163.com>