
Conversation

@rainyfly (Collaborator) commented Jul 20, 2025

Background

In the current version, KV cache blocks are divided into two tables: the table for prefilling is managed by the server, while the one for decoding is managed by the worker. Users must set kv_cache_ratio to determine how many blocks each phase gets.
When there are not enough blocks for decoding, rescheduling is handled by the step_paddle op, which easily causes OOM when the output length is large.
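For illustration, a hedged sketch of the legacy ratio split (the variable names and the direction of the split are assumptions, not FastDeploy's actual code):

```python
total_blocks = 1024
kv_cache_ratio = 0.75  # user-tuned knob that scheduler v1 removes

# The ratio splits one fixed pool into the two separately managed tables.
prefill_blocks = int(total_blocks * kv_cache_ratio)  # server-managed table
decode_blocks = total_blocks - prefill_blocks        # worker-managed table
```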

To make deployment easier and more stable, we have started developing scheduler v1: a single block table exists and is managed by the server, as sketched below.
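For intuition, a minimal toy model of the single-table idea (illustrative names only, not the actual scheduler v1 code): the server owns one free list and serves prefill and decode requests from the same pool.

```python
class UnifiedBlockTable:
    """Toy model of a single server-managed KV cache block table."""

    def __init__(self, num_blocks: int):
        self.free_blocks = list(range(num_blocks))  # one pool for all phases

    def allocate(self, n: int) -> list[int]:
        if len(self.free_blocks) < n:
            raise RuntimeError("no free KV cache blocks; request must wait")
        taken, self.free_blocks = self.free_blocks[:n], self.free_blocks[n:]
        return taken

    def free(self, blocks: list[int]) -> None:
        self.free_blocks.extend(blocks)  # blocks return to the shared pool
```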
Note
Scheduler v1 is experimental and currently supports only the mixed and sampling scenarios. We will continue to build on and improve it in the future.

Perf

  • When kv_cache_ratio is set appropriately, performance with scheduler v1 is 3-5% worse than before due to the extra communication overhead between the engine and the worker.
  • When kv_cache_ratio is too high or too low, performance with scheduler v1 is better.

Stability

  • No OOM was observed during testing.

How to enable

Set the environment variable ENABLE_V1_KVCACHE_SCHEDULER to 1 to enable scheduler v1:

```bash
export ENABLE_V1_KVCACHE_SCHEDULER=1
```
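The flag can also be set from Python before FastDeploy starts; a minimal sketch (assumes the flag is read when FastDeploy loads its environment configuration):

```python
import os

# Must be set before FastDeploy reads ENABLE_V1_KVCACHE_SCHEDULER.
os.environ["ENABLE_V1_KVCACHE_SCHEDULER"] = "1"
```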


paddle-bot bot commented Jul 20, 2025

Thanks for your contribution!

@rainyfly rainyfly requested a review from Jiang-Jia-Jun July 20, 2025 16:28
@Jiang-Jia-Jun Jiang-Jia-Jun requested a review from Copilot July 21, 2025 02:05
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR introduces experimental support for block scheduler v1 in FastDeploy, which unifies block management under a single server-managed cache system. The new scheduler eliminates the need for manual kv_cache_ratio configuration and aims to improve stability by preventing OOM issues during long sequence generation.

  • Adds a new unified block scheduling system managed entirely by the server
  • Implements conditional logic throughout the codebase to support both legacy and v1 schedulers (see the sketch after this list)
  • Introduces new CUDA operations for v1-specific input handling and task recovery
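A hedged sketch of that dual-path pattern (the class names below are placeholders; the real v1 implementation lives in fastdeploy/engine/sched/resource_manager_v1.py):

```python
import os

class LegacyResourceManager:
    """Placeholder for the existing prefill/decode split manager."""

class ResourceManagerV1:
    """Placeholder for the new server-managed single-table manager."""

def make_resource_manager():
    # Every dual code path in this PR is gated on the same environment flag.
    if os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0") == "1":
        return ResourceManagerV1()
    return LegacyResourceManager()
```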

Reviewed Changes

Copilot reviewed 15 out of 15 changed files in this pull request and generated 6 comments.

Summary per file:

| File | Description |
| --- | --- |
| fastdeploy/envs.py | Adds environment variable to enable v1 scheduler |
| fastdeploy/engine/engine.py | Integrates v1 resource manager and scheduler logic |
| fastdeploy/engine/sched/resource_manager_v1.py | Core v1 resource manager implementation |
| fastdeploy/worker/gpu_worker.py | Adds v1 task insertion logic |
| fastdeploy/worker/gpu_model_runner.py | Implements v1 task processing and input management |
| fastdeploy/model_executor/pre_and_post_process.py | Updates post-processing for v1 scheduler |
| fastdeploy/output/token_processor.py | Adds v1-specific resource recycling |
| fastdeploy/engine/request.py | Extends Request class with scheduler state |
| custom_ops/ | New CUDA operations for v1 input handling |
Comments suppressed due to low confidence (1)

fastdeploy/worker/gpu_model_runner.py:269

  • The variable 'i' is used in a nested loop context where there's already an outer loop variable 'i' at line 199. This creates a naming conflict and potential confusion. Consider renaming the inner loop variable to 'j' or 'seq_idx'.
```python
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
```
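A minimal before/after illustration of the suggested rename (the bounds are placeholder values):

```python
stop_seqs_num, max_stop_seqs_num = 2, 4

for i in range(stop_seqs_num):  # outer loop keeps `i`
    for j in range(stop_seqs_num, max_stop_seqs_num):  # inner loop renamed from `i` to `j`
        pass  # pad the unused stop-sequence slots here
```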

Comment on lines +194 to +197
```python
# NOTE(luotingdan): Lazy initialize kv cache
if "caches" not in self.share_inputs:
    self.initialize_kv_cache()
```


After merging #2924, we need to delete the branch.

```python
)  # All gpu blocks are managed by cache manager
else:
    self.num_gpu_blocks = cache_config.prefill_kvcache_block_num
    self.gpu_free_block_list = list(
```

This is redundant with the branch above; it can be moved below.

@Jiang-Jia-Jun Jiang-Jia-Jun merged commit 85a78d6 into PaddlePaddle:develop Jul 23, 2025
5 of 6 checks passed
lizexu123 pushed a commit to lizexu123/FastDeploy that referenced this pull request Jul 25, 2025
* Support FD block scheduler v1

* Support FD block scheduler v1

* Support FD block scheduler v1

* Fix according to copilot review

* Fix according to review

* Remove is_dummy

* Fix bug when real_bsz=1

* Fix infer first token cost time

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
luukunn added a commit to luukunn/FastDeploy that referenced this pull request Jul 29, 2025
* [MTP Fix] Fix code and register cpp operators (PaddlePaddle#2965)

* fix rl config local rank (PaddlePaddle#2957)

* [FIX]fix rejection sampling when topp=0 using _SAMPLING_EPS (PaddlePaddle#2967)

* fix rejection sampling when topp=0

* fix

* [SOT] Add sot warmup (NVIDIA GPU Only) (PaddlePaddle#2929)

* add sot warmup

* fix code style

* change batch_size list

* add param to config

* rm free_list settings && set sot_warmup_sizes

* finish debug with dynamic dims by type annotations

* add profile_run guard

* rm sth useless

* support chunk_prefill in fa3

* 【Infer】Improve the performance block_wise_fp8 of triton_moe_backend (PaddlePaddle#2942)

* Update README.md

* Update README.md

* delete max-len (PaddlePaddle#2959)

* [CI] add codestyle_check action (PaddlePaddle#2972)

* [CI] add codestyle_check action

* [CI] Integrate codestyle check via pre-commit in GitHub Actions

* fix mtp bug in pd-split mode (PaddlePaddle#2970)

* [BugFix] Add prefill restrictions for chunked_prefill+VL (PaddlePaddle#2983)

* Fix performance degradation bug of custom_all_reduce (PaddlePaddle#2981)

* FA3 fix bug (PaddlePaddle#2987)

* polish code for prefill restrictions (PaddlePaddle#2991)

* [Feature] Support block scheduler v1 for FD (PaddlePaddle#2928)

* Support FD block scheduler v1

* Support FD block scheduler v1

* Support FD block scheduler v1

* Fix according to copilot review

* Fix according to review

* Remove is_dummy

* Fix bug when real_bsz=1

* Fix infer first token cost time

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>

* update (PaddlePaddle#2978)

* [Code Simplification] fix init_distributed_environment() (PaddlePaddle#2982)

* support c4 attn && fix cache

* fix chunk_prefill

* [benchmark] add quantization for benchmark yaml (PaddlePaddle#2995)

* [Fix] fix mm ep empty run (PaddlePaddle#2999)

* add ci reuse action (PaddlePaddle#2968)

* add ci reuse action

* fix code formatting

* update

* [Feature] multi-source download (PaddlePaddle#2986)

* multi-source download

* multi-source download

* huggingface download revision

* requirement

* style

* add revision arg

* test

* pre-commit

* [LLM] update function name (PaddlePaddle#2985)

* [LLM] update function name

* [BugFix] fix multinode deployment (PaddlePaddle#2977)

* Update benchmark tools (PaddlePaddle#3004)

* update benchmark tools

* update benchmark tools

* update flake8 version to support pre-commit in python3.12 (PaddlePaddle#3000)

* update flake8 version to support pre-commit in python3.12

* polish code

* [Feature] multi source download (PaddlePaddle#3005)

* multi-source download

* multi-source download

* huggingface download revision

* requirement

* style

* add revision arg

* test

* pre-commit

* Change default download

* change requirements.txt

* modify English Documentation

* documentation

* [GCU] Update to develop (PaddlePaddle#2988)

* [Model] Provide clearer error for missing KV cache quantization scales (PaddlePaddle#3007)

* [Feature] Support_eplb (PaddlePaddle#2997)

* [Feature] support_eplb

* [Feature] support_eplb

* [Fix] fix mm ep

* Update setup.py

* [feat] add disable_chat_template in chat api as a substitute for previous raw_request (PaddlePaddle#3023)

* [feat] add disable_chat_template in chat api as a substitute for previous raw_request

* [fix] pre-commit code check

---------

Co-authored-by: GoldPancake <56388518+Deleter-D@users.noreply.github.com>
Co-authored-by: gaoziyuan <88373061+gzy19990617@users.noreply.github.com>
Co-authored-by: Sunny-bot1 <68891411+Sunny-bot1@users.noreply.github.com>
Co-authored-by: Ryan <zihaohuang@aliyun.com>
Co-authored-by: lizhenyun01 <1500424927@qq.com>
Co-authored-by: chen <103103266+ckl117@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: lizexu123 <39205361+lizexu123@users.noreply.github.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: freeliuzc <lzc842650834@gmail.com>
Co-authored-by: Zero Rains <linjunlu@zerorains.top>
Co-authored-by: zhink <33270771+zhink@users.noreply.github.com>
Co-authored-by: chenjian <1435317881@qq.com>
Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
Co-authored-by: xiegegege <46314656+xiegegege@users.noreply.github.com>
Co-authored-by: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com>
Co-authored-by: YUNSHEN XIE <1084314248@qq.com>
Co-authored-by: Yzc216 <101054010+Yzc216@users.noreply.github.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Co-authored-by: Zhang Yulong <35552275+ZhangYulongg@users.noreply.github.com>
Co-authored-by: EnflameGCU <118410644+EnflameGCU@users.noreply.github.com>
Co-authored-by: littledgg <61149469+littledgg@users.noreply.github.com>
Co-authored-by: 李泳桦 <39643373+liyonghua0910@users.noreply.github.com>