
Conversation

@ddchenhao66
Collaborator

Motivation

Support multimodal prefix cache on XPU.

Modifications

  1. Remove the code in the config file that disabled prefix cache by default for multimodal models on XPU;
  2. Add multimodal prefix cache support in xpu_model_runner.

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. If no unit tests are added, please explain the reason in this PR.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


ddchenhao66 does not appear to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

Contributor

Copilot AI left a comment


Pull request overview

This PR adds multimodal prefix cache support for XPU. The main changes are removing the restriction that disabled prefix cache by default on the XPU platform, and implementing an encoder cache in xpu_model_runner to cache vision features.

Key changes:

  • Removed the configuration restriction that disabled prefix cache by default for multimodal models on the XPU platform
  • Added an encoder_cache mechanism to cache and reuse extracted vision features (illustrated in the sketch after this list)
  • Refactored the multimodal input handling logic, extracting it into an _apply_mm_inputs method
  • Optimized the CUDAGraph device-type check logic
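For readers unfamiliar with the encoder cache idea, a minimal sketch of the mechanism follows. It is an illustration only, built around the names visible in the diff quoted further down (encoder_cache, mm_hashes, extract_vision_features); the actual xpu_model_runner implementation is structured differently.

import paddle


class EncoderCacheSketch:
    # Illustrative only: extracted vision features are keyed by a content hash,
    # so images already processed for an earlier request skip the vision encoder.
    def __init__(self):
        self.encoder_cache: dict[str, paddle.Tensor] = {}

    def get_features(self, mm_hashes, images, extract_vision_features):
        # Run the vision encoder only for images whose hash is not cached yet.
        for mm_hash, image in zip(mm_hashes, images):
            if mm_hash not in self.encoder_cache:
                self.encoder_cache[mm_hash] = extract_vision_features(image)
        # Reassemble the feature sequence in prompt order from the cache.
        return paddle.concat([self.encoder_cache[h] for h in mm_hashes], axis=0)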

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File Description
fastdeploy/engine/args_utils.py Removed the code that automatically disabled prefix cache for multimodal models on the XPU platform
fastdeploy/config.py Refactored the CUDAGraph device-type check to use current_platform.is_cuda() instead of accessing device_config directly (see the sketch after this table)
fastdeploy/worker/xpu_model_runner.py Added encoder_cache initialization and multimodal prefix cache support, including the new get_chunked_inputs, batch_uncached_inputs, scatter_and_cache_features, and _apply_mm_inputs methods
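The config.py change can be sketched roughly as below. The attribute names (graph_opt_config, use_cudagraph, device_config) are assumptions for illustration based on the description above, not a quote of the actual diff.

from fastdeploy.platforms import current_platform


def maybe_disable_cudagraph(graph_opt_config):
    # Sketch of the refactored check: ask the platform abstraction whether we
    # are on CUDA, rather than reading a device string off device_config.
    if not current_platform.is_cuda():
        graph_opt_config.use_cudagraph = False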

Comment on lines +329 to +334
elif "qwen" in self.model_config.model_type:
actual_image_token_num = paddle.sum(
vision_inputs["input_ids"] == vision_inputs["image_patch_id"]
) + paddle.sum(vision_inputs["input_ids"] == vision_inputs["video_patch_id"])
else:
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")

Copilot AI Dec 3, 2025


The logic for calculating actual_image_token_num for "qwen" model type attempts to access vision_inputs["image_patch_id"] and vision_inputs["video_patch_id"], but these fields are not set in the vision_inputs dictionary in either the batch_uncached_inputs method or the get_chunked_inputs method. This will cause a KeyError at runtime for qwen models.

Looking at the GPU model runner implementation, these patch IDs should be provided in the vision_inputs. Please ensure that these fields are properly populated in the vision_inputs dict before this calculation, or modify the logic to retrieve them from an appropriate source (e.g., from request.multimodal_inputs or self.model_config).
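A sketch of one possible fix along these lines is below; the image_patch_id and video_patch_id attributes on model_config are hypothetical and would need to match wherever qwen-VL actually exposes them (e.g. request.multimodal_inputs).

import paddle


def count_qwen_image_tokens(vision_inputs: dict, model_config) -> paddle.Tensor:
    # Hypothetical helper: resolve the patch IDs from the model config, falling
    # back to the multimodal inputs, so the lookup cannot raise KeyError, then
    # count the image/video placeholder tokens in the chunked input_ids.
    image_patch_id = getattr(model_config, "image_patch_id", vision_inputs.get("image_patch_id"))
    video_patch_id = getattr(model_config, "video_patch_id", vision_inputs.get("video_patch_id"))
    return paddle.sum(vision_inputs["input_ids"] == image_patch_id) + paddle.sum(
        vision_inputs["input_ids"] == video_patch_id
    )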

    self.encoder_cache: dict[str, paddle.Tensor] = {}
else:
    self.encoder_cache = None


Copilot AI Dec 3, 2025


Trailing whitespace detected. Please remove the trailing whitespace for code cleanliness.

Comment on lines +270 to +336
inputs = request.multimodal_inputs
if request.with_image:
    if envs.FD_ENABLE_MAX_PREFILL:
        multi_vision_inputs["images_lst"].append(
            inputs["images"][request.image_start : request.image_end].cuda()
        )
        multi_vision_inputs["grid_thw_lst"].extend(
            inputs["grid_thw"][request.num_image_start : request.num_image_end]
        )
        multi_vision_inputs["cu_seqlens"].extend(
            inputs["vit_seqlen"][request.num_image_start : request.num_image_end]
        )
        multi_vision_inputs["vit_position_ids_lst"].extend(
            inputs["vit_position_ids"][request.num_image_start : request.num_image_end]
        )
    else:
        vision_inputs = inputs
        if self.encoder_cache:
            (
                vision_inputs["input_ids"],
                vision_inputs["token_type_ids"],
                vision_inputs["image_type_ids"],
                vision_inputs["images"],
                vision_inputs["grid_thw"],
                vision_inputs["mm_hashes"],
            ) = self.batch_uncached_inputs(request)
            if len(vision_inputs["mm_hashes"]) > 0:
                # uncached multimodal inputs exist
                image_features = self.extract_vision_features(vision_inputs)
                self.scatter_and_cache_features(image_features, vision_inputs)

            full_image_features_lst = []
            for mm_hash in inputs["mm_hashes"][request.num_image_start : request.num_image_end]:
                feature = self.encoder_cache[mm_hash].cuda()
                full_image_features_lst.append(feature)
            image_features = paddle.concat(full_image_features_lst, axis=0)
        else:
            (
                input_ids,
                token_type_ids,
                image_type_ids,
                images,
                grid_thw,
                mm_hashes,
            ) = self.get_chunked_inputs(request)
            vision_inputs["input_ids"] = paddle.to_tensor(input_ids, dtype=paddle.int64)
            vision_inputs["token_type_ids"] = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
            vision_inputs["image_type_ids"] = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
            vision_inputs["images"] = paddle.to_tensor(
                images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16"
            )
            vision_inputs["grid_thw"] = paddle.to_tensor(grid_thw, dtype=paddle.int64)
            vision_inputs["mm_hashes"] = mm_hashes

            image_features = self.extract_vision_features(vision_inputs)

        # part of the first image may be already cached
        if "ernie" in self.model_config.model_type:
            actual_image_token_num = paddle.sum(vision_inputs["input_ids"] == self.model_config.im_patch_id)
        elif "qwen" in self.model_config.model_type:
            actual_image_token_num = paddle.sum(
                vision_inputs["input_ids"] == vision_inputs["image_patch_id"]
            ) + paddle.sum(vision_inputs["input_ids"] == vision_inputs["video_patch_id"])
        else:
            raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
        self.share_inputs["image_features"] = image_features[-actual_image_token_num:]


Copilot AI Dec 3, 2025


Missing handling for the case when request.with_image is False. The old code (before refactoring) explicitly set self.share_inputs["image_features"] = None when there were no images. Without this, when multimodal is enabled but a request batch contains no images, self.share_inputs["image_features"] may retain stale data from previous batches.

The recommended fix, following the GPU model runner pattern (gpu_model_runner.py line 566), is to initialize self.share_inputs["image_features"] = None at the beginning of the insert_tasks_v1 method (after line 358), rather than in this method.
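As a rough illustration of that suggestion (names and signature abbreviated; enable_mm is assumed to be the runner's multimodal flag):

def insert_tasks_v1(self, req_dicts, num_running_requests=None):
    # Reset per-batch multimodal state up front so a batch without images does
    # not reuse image_features left over from a previous batch.
    if self.enable_mm:
        self.share_inputs["image_features"] = None
    ...  # remaining task-insertion logic unchanged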

@kevincheng2
Collaborator

I recently rewrote the _apply_mm_inputs method to support multi-batch multimodal prefill and fix an encoder cache bug. This PR can be merged first; we can check later whether it needs to be updated.


inputs = request.multimodal_inputs
if request.with_image:
    if envs.FD_ENABLE_MAX_PREFILL:
Collaborator


envs.FD_ENABLE_MAX_PREFILL is the paddle_ocr vl logic, which supports multi-batch prefill; whether this path works correctly on XPU may need to be checked.

Collaborator Author


If envs.FD_ENABLE_MAX_PREFILL is tied to the paddle_ocr vl model, should the if envs.FD_ENABLE_MAX_PREFILL: check also include a model-type check, or be wrapped into a function? XPU now also supports the extract_vision_features_paddleocr function, which likewise involves FD_ENABLE_MAX_PREFILL, so that should probably be added in this change as well.
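One way to phrase the wrapper being proposed here, as a hypothetical sketch (the model-type substring and the envs import path are assumptions, not code from this PR):

from fastdeploy import envs


def use_max_prefill_batching(model_config) -> bool:
    # Take the FD_ENABLE_MAX_PREFILL multi-batch prefill path only for the
    # model family it was written for, instead of keying off the env var alone.
    return bool(envs.FD_ENABLE_MAX_PREFILL) and "paddleocr" in model_config.model_type

Call sites in xpu_model_runner would then check use_max_prefill_batching(self.model_config) instead of envs.FD_ENABLE_MAX_PREFILL directly.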

@codecov-commenter

codecov-commenter commented Dec 3, 2025

Codecov Report

❌ Patch coverage is 0% with 2 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@f458cc5). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/config.py 0.00% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #5356   +/-   ##
==========================================
  Coverage           ?   59.34%           
==========================================
  Files              ?      324           
  Lines              ?    40058           
  Branches           ?     6051           
==========================================
  Hits               ?    23772           
  Misses             ?    14402           
  Partials           ?     1884           
Flag Coverage Δ
GPU 59.34% <0.00%> (?)



@EmmonsCurse EmmonsCurse merged commit 4e8096b into PaddlePaddle:develop Dec 3, 2025
12 of 17 checks passed
@paddle-bot

paddle-bot bot commented Dec 3, 2025

Thanks for your contribution!

