[Speculative Decoding] Support multi-step mtp with cudagraph #5624
Conversation
Thanks for your contribution!
Pull request overview
This PR adds support for multi-step MTP (Multi-Token Prediction) with CUDAGraph. The changes enable proper capture of CUDA graphs for MTP scenarios by adjusting capture sizes to account for multiple tokens generated per query per step.
Key Changes
- Modified CUDA graph capture logic for MTP target model to use dynamic batch size calculation based on speculative token count
- Updated `_set_cudagraph_sizes` to generate capture sizes scaled by tokens per query per step
- Simplified target model capture by removing special handling for batch size 1
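The capture-size scaling described above can be sketched as follows. This is a minimal illustration, not the PR's actual implementation; the helper name is hypothetical, while `dec_token_per_query_per_step` is the parameter the PR adds to `fastdeploy/config.py`:

```python
def scale_capture_sizes(base_sizes, dec_token_per_query_per_step):
    """Scale CUDA graph capture sizes so each captured graph covers all
    tokens a batch of queries emits in one multi-step MTP step."""
    return [size * dec_token_per_query_per_step for size in base_sizes]

# e.g. base batch sizes [1, 2, 4, 8] with 2 tokens per query per step
print(scale_capture_sizes([1, 2, 4, 8], 2))  # [2, 4, 8, 16]
```

Scaling at capture time keeps the graph's padded token count aligned with what the target model actually receives during multi-token verification.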
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `fastdeploy/worker/gpu_model_runner.py` | Simplified MTP target model capture logic, removed batch size 1 skip condition, updated batch size and `expected_decode_len` calculations |
| `fastdeploy/config.py` | Added `dec_token_per_query_per_step` parameter to scale CUDA graph capture sizes appropriately for multi-step MTP |
Codecov Report ✅ All modified and coverable lines are covered by tests.

```
@@ Coverage Diff @@
## develop    #5624   +/- ##
==========================================
  Coverage     ?    62.88%
==========================================
  Files        ?       329
  Lines        ?     41700
  Branches     ?      6368
==========================================
  Hits         ?     26223
  Misses       ?     13492
  Partials     ?      1985
```

Flags with carried forward coverage won't be shown.
gongshaotian left a comment:
LGTM
…ddle#5624) * support multi-step mtp with cudagraph * fix usage * fix unit test
```python
if batch_size == 1:
    logger.info("Skip token_num = 1, when capture Draft model for mtp")
else:
    assert batch_size % 2 == 0
```
Remove the `assert`.
```python
    if self.scheduler_config.splitwise_role == "decode"
    else self.scheduler_config.max_num_batched_tokens
),
batch_size=int(batch_size / 2),
```
```python
batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
```
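The suggested change derives the batch size from the captured token count rather than halving it. A minimal sketch of that relationship, assuming each query emits one verified token plus `num_speculative_tokens` draft tokens per step (the helper name is hypothetical):

```python
def capture_batch_size(capture_size, num_speculative_tokens):
    """A CUDA graph captured for capture_size tokens serves
    capture_size // (num_speculative_tokens + 1) queries, since each
    query consumes one verified token plus its draft tokens per step."""
    tokens_per_query = num_speculative_tokens + 1
    # capture sizes are generated as multiples of tokens_per_query
    assert capture_size % tokens_per_query == 0
    return capture_size // tokens_per_query

print(capture_batch_size(64, 1))  # 32
```

With one speculative token this reproduces the old `batch_size / 2`, but it stays correct when the step count or draft token count changes.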
```python
),
batch_size=int(batch_size / 2),
in_capturing=True,
expected_decode_len=3,
```
This, together with the exit logic in `_dummy_run()`, needs to be updated.
1 + draft token + draft model EOS token
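The breakdown above explains why the hard-coded `expected_decode_len=3` in the diff works only for a single draft token: one target-model token, the draft tokens, and one slot for the draft model's EOS token. A hedged sketch of the general formula (helper name hypothetical):

```python
def expected_decode_len(num_draft_tokens):
    # 1 target-model token + the draft tokens + 1 slot reserved for the
    # draft model's EOS token, per the review comment above
    return 1 + num_draft_tokens + 1

print(expected_decode_len(1))  # 3, matching the value hard-coded in the diff
```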
```python
logger.info(
    f"Warm up the Target model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
)
if self.graph_opt_config.draft_model_use_cudagraph:
```
Enable this launch flag.
Motivation
Modifications
Usage or Command
Accuracy Tests
Checklist
- Add at least one tag to the PR title from: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run `pre-commit` before commit.
- For a `release` branch PR, make sure the PR has been submitted to the `develop` branch first, then cherry-pick it to the `release` branch with the [Cherry-Pick] PR tag.