[XPU] add speculate_get_logits#5497
Conversation
|
Thanks for your contribution! |
| baidu::xpu::api::Context* ctx = | ||
| static_cast<const phi::XPUContext*>(dev_ctx)->x_context(); | ||
| if (draft_logits.is_cpu()) { | ||
| ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU); |
There was a problem hiding this comment.
这个 context 什么时候被销毁掉呢?是否会造成内存泄露?
There was a problem hiding this comment.
CPU一般是用于单测验证,除了这里其他的算子可能也没对这个cpu的ctx做释放,后续可能需要统一排查一下
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## develop #5497 +/- ##
==========================================
Coverage ? 60.27%
==========================================
Files ? 329
Lines ? 41114
Branches ? 6261
==========================================
Hits ? 24782
Misses ? 14443
Partials ? 1889
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
| seq_lens_encoder, | ||
| real_bsz, | ||
| vocab_size); | ||
| WRAPPER_DUMP(ctx); |
| if (clus_id < real_bsz && cid == 0) { | ||
| GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, real_bsz * sizeof(int)); | ||
| GM2SM(seq_lens_this_time, sm_seq_lens_this_time, real_bsz * sizeof(int)); | ||
| int next_token_num_previous = 0; | ||
| for (int bid = 0; bid < real_bsz; bid++) { | ||
| sm_batch_token_num[bid] = | ||
| sm_seq_lens_encoder[bid] > 0 ? 2 : sm_seq_lens_this_time[bid]; | ||
| if (bid == 0) { | ||
| sm_cu_batch_token_offset[bid] = 0; | ||
| sm_cu_next_token_offset[bid] = 0; | ||
| } else { | ||
| sm_cu_batch_token_offset[bid] = | ||
| sm_cu_batch_token_offset[bid - 1] + sm_batch_token_num[bid - 1]; | ||
| sm_cu_next_token_offset[bid] = | ||
| sm_cu_next_token_offset[bid - 1] + next_token_num_previous; | ||
| } | ||
| next_token_num_previous = | ||
| sm_seq_lens_encoder[bid] > 0 ? 1 : sm_seq_lens_this_time[bid]; | ||
| } | ||
| mfence_sm(); | ||
| if (clus_id == 0) { | ||
| SM2GM_ASYNC(sm_batch_token_num, batch_token_num, real_bsz * sizeof(int)); | ||
| SM2GM_ASYNC(sm_cu_batch_token_offset, | ||
| cu_batch_token_offset, | ||
| real_bsz * sizeof(int)); | ||
| } | ||
| } |
There was a problem hiding this comment.
这个部分的代码逻辑,21行-40行是每个 cluster 都会执行,41-46行只有 cluster0 会执行,是不是等价于21行-46行的代码实际上只有 clus_id == 0 执行的有用?
There was a problem hiding this comment.
这里的prefix sum所有参与计算的cluster都需要使用,所以21-40所有< real_bsz的cluster都要计算一份;但是写回gm的话,只要一个cluster写就行,所以41-46只要cluster0
There was a problem hiding this comment.
这里的prefix sum所有参与计算的cluster都需要使用,所以21-40所有< real_bsz的cluster都要计算一份;但是写回gm的话,只要一个cluster写就行,所以41-46只要cluster0
好的,明白了
Motivation
add speculate_get_logits
Modifications
add speculate_get_logits
Usage or Command
No
Accuracy Tests
No
Checklist
[FDConfig],[APIServer],[Engine],[Scheduler],[PD Disaggregation],[Executor],[Graph Optimization],[Speculative Decoding],[RL],[Models],[Quantization],[Loader],[OP],[KVCache],[DataProcessor],[BugFix],[Docs],[CI],[Optimization],[Feature],[Benchmark],[Others],[XPU],[HPU],[GCU],[DCU],[Iluvatar],[Metax]]pre-commitbefore commit.releasebranch, make sure the PR has been submitted to thedevelopbranch, then cherry-pick it to thereleasebranch with the[Cherry-Pick]PR tag.