
Conversation

@freeliuzc
Collaborator

  • support multi-step mtp with cudagraph

  • fix usage

  • fix unit test

Motivation

💡 If this PR is a cherry-pick, the PR title must follow the required format: add the [Cherry-Pick] label at the very beginning and append the original PR ID at the end, for example [Cherry-Pick][CI] Add check trigger and logic(#5191).


Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. If no unit tests are added, please explain the reason in this PR.
  • Provide accuracy results.
  • If the current PR targets the release branch, make sure it has already been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

…ddle#5624)

* support multi-step mtp with cudagraph

* fix usage

* fix unit test
Copilot AI review requested due to automatic review settings December 22, 2025 03:41
@paddle-bot

paddle-bot bot commented Dec 22, 2025

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR is a cherry-pick from #5624 that adds support for multi-step MTP (Multi-Token Prediction) with CUDA graph optimization. The changes modify how CUDA graph capture sizes are calculated and how the target model is warmed up during the capture process to support multi-step speculative decoding with MTP.

Key Changes:

  • Modified CUDA graph capture size calculation to account for multiple tokens per query per step in MTP scenarios (see the sketch after this list)
  • Updated target model warm-up logic to use dynamic batch size calculation based on speculative tokens
  • Removed the skip condition for batch_size=1 in target model capture
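
A minimal sketch of the capture-size scaling described in the first change above. The helper name and the base size schedule here are illustrative assumptions, not the actual code; the real logic lives in `_set_cudagraph_sizes` in fastdeploy/config.py:

```python
# Illustrative sketch only: how capture sizes could be scaled for multi-step MTP.
# With MTP, each query contributes (num_speculative_tokens + 1) decode tokens per
# step, so every candidate capture size is multiplied by that factor.

def build_capture_sizes(max_capture_size: int, dec_token_per_query_per_step: int = 1) -> list:
    """Return candidate CUDA graph capture sizes scaled by tokens per query per step."""
    base = [1, 2, 4] + list(range(8, max_capture_size + 1, 8))  # assumed base schedule
    return sorted({size * dec_token_per_query_per_step for size in base})

# Example: num_speculative_tokens = 1 (one extra MTP token) -> factor 2
print(build_capture_sizes(16, dec_token_per_query_per_step=2))
# [2, 4, 8, 16, 32]
```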

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

Files reviewed:
  • fastdeploy/config.py: Added a dec_token_per_query_per_step parameter to the _set_cudagraph_sizes method so capture sizes are multiplied by the number of tokens per step, and updated the initialization logic to compute this value from the speculative configuration.
  • fastdeploy/worker/gpu_model_runner.py: Simplified the target model capture logic by removing the batch_size=1 skip condition and changing the batch size calculation to divide by (num_speculative_tokens + 1) instead of a hardcoded division by 2.

```python
self.graph_opt_config._set_cudagraph_sizes(max_capture_size=max_capture_shape)
dec_token_per_query_per_step = (
    self.speculative_config.num_speculative_tokens + 1
    if self.speculative_config is not None and self.speculative_config.method is not None
```

Copilot AI Dec 22, 2025


The condition self.speculative_config.method is not None is too broad and should be more specific. This will cause dec_token_per_query_per_step to be calculated for any speculative method, not just "mtp". The logic should check if self.speculative_config.method == "mtp" to align with the specific MTP requirements in lines 1571-1574 above.

Suggested change

```diff
-if self.speculative_config is not None and self.speculative_config.method is not None
+if self.speculative_config is not None and self.speculative_config.method == "mtp"
```
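
To make the difference concrete, here is a small, purely illustrative comparison of the two conditions for a hypothetical non-MTP speculative method (the values are made up):

```python
# Illustrative only: the scaling factor under the original vs. suggested condition
# for a hypothetical non-MTP speculative method.
num_speculative_tokens = 1
method = "ngram"  # any speculative method other than "mtp"

factor_broad = num_speculative_tokens + 1 if method is not None else 1  # original condition: 2
factor_mtp_only = num_speculative_tokens + 1 if method == "mtp" else 1  # suggested condition: 1
print(factor_broad, factor_mtp_only)  # 2 1
```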

Comment on lines +2089 to +2095
```python
self._dummy_run(
    num_tokens=(
        self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
        if self.scheduler_config.splitwise_role == "decode"
        else self.scheduler_config.max_num_batched_tokens
    ),
    batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
```

Copilot AI Dec 22, 2025


The division operation int(capture_size / (self.speculative_config.num_speculative_tokens + 1)) can result in truncation if capture_size is not evenly divisible. This could lead to incorrect batch sizes being used for warming up the target model. Consider adding a validation check or assertion to ensure that capture_size is divisible by (num_speculative_tokens + 1), similar to the assertion on line 1575 in config.py that validates the max_capture_shape is even.

Suggested change

```diff
-self._dummy_run(
-    num_tokens=(
-        self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
-        if self.scheduler_config.splitwise_role == "decode"
-        else self.scheduler_config.max_num_batched_tokens
-    ),
-    batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
+group_size = self.speculative_config.num_speculative_tokens + 1
+assert (
+    capture_size % group_size == 0
+), f"cudagraph capture_size {capture_size} must be divisible by (num_speculative_tokens + 1) = {group_size}"
+self._dummy_run(
+    num_tokens=(
+        self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
+        if self.scheduler_config.splitwise_role == "decode"
+        else self.scheduler_config.max_num_batched_tokens
+    ),
+    batch_size=capture_size // group_size,
```
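
A small illustration of the truncation risk described above, with made-up numbers:

```python
# Hypothetical numbers only: what silent truncation would look like.
num_speculative_tokens = 2            # group_size = 3 decode tokens per query per step
group_size = num_speculative_tokens + 1
capture_size = 8                      # not a multiple of group_size

batch_size = int(capture_size / group_size)
print(batch_size, batch_size * group_size)  # 2 6 -> warm-up covers 6 tokens, not 8

# The suggested assertion turns this mismatch into a loud failure instead of
# silently capturing a graph for the wrong batch size.
```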

Comment on lines 906 to 907
Calculate a series of candidate capture sizes,
and then extract a portion of them as the capture list for the CUDA graph based on user input.

Copilot AI Dec 22, 2025


The documentation for this method should be updated to describe the new dec_token_per_query_per_step parameter. The docstring currently doesn't mention this parameter or explain its purpose in the capture size calculations.

Suggested change

```diff
-Calculate a series of candidate capture sizes,
-and then extract a portion of them as the capture list for the CUDA graph based on user input.
+Calculate a series of candidate CUDA graph capture sizes and
+extract a subset of them as the final capture list based on user input.
+
+The method builds a base sequence of capture sizes (e.g. 1, 2, 4, 8, 16, …)
+and multiplies every element by ``dec_token_per_query_per_step`` so that all
+capture sizes correspond to a specific number of decoding tokens per query
+per step.
+
+Args:
+    max_capture_size (int): The maximum capture size to consider. This value
+        is always included in the candidate list.
+    dec_token_per_query_per_step (int): Number of decoding tokens per query
+        per step. All candidate capture sizes are scaled by this factor, so
+        each capture size is an integer multiple of this value.
```
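
A tiny runnable check of the property the suggested docstring states; the numbers are assumptions chosen for illustration:

```python
# Minimal check, with assumed numbers, that scaled capture sizes are multiples
# of dec_token_per_query_per_step as the docstring describes.
dec_token_per_query_per_step = 3  # e.g. num_speculative_tokens = 2
capture_sizes = [s * dec_token_per_query_per_step for s in (1, 2, 4, 8)]
assert all(s % dec_token_per_query_per_step == 0 for s in capture_sizes)
print(capture_sizes)  # [3, 6, 12, 24]
```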

@codecov-commenter

Codecov Report

❌ Patch coverage is 90.00000% with 1 line in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.4@9c55bc3); the BASE report is missing.

Files with missing lines Patch % Lines
fastdeploy/config.py 85.71% 1 Missing ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.4    #5670   +/-   ##
==============================================
  Coverage               ?   58.95%           
==============================================
  Files                  ?      327           
  Lines                  ?    40765           
  Branches               ?     6200           
==============================================
  Hits                   ?    24031           
  Misses                 ?    14860           
  Partials               ?     1874           
Flag Coverage Δ
GPU 58.95% <90.00%> (?)

Flags with carried forward coverage won't be shown.


@freeliuzc freeliuzc merged commit ceafd75 into PaddlePaddle:release/2.4 Dec 23, 2025
35 of 39 checks passed
