
[Models] Add forward_meta to moe models' forward function #5138

Merged
Wanglongzhi2001 merged 9 commits into PaddlePaddle:develop from Wanglongzhi2001:add_forward_meta
Dec 4, 2025

Conversation

Wanglongzhi2001 (Collaborator) commented Nov 20, 2025

Motivation

In some scenarios, such as chunked MoE, we need to update MoE state during the forward pass. It is reasonable to keep this state in forward_meta, so the forward_meta parameter needs to be added to FusedMoE's forward function.
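A minimal sketch of the intended signature change (not the actual FastDeploy code; the constructor and the quant_method wiring shown here are simplified stand-ins):

import paddle
from paddle import nn


class FusedMoE(nn.Layer):
    """Sketch only: forward now takes forward_meta so per-step MoE state travels with the call."""

    def __init__(self, quant_method):
        super().__init__()
        self.quant_method = quant_method  # backend that actually dispatches tokens to experts

    def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta) -> paddle.Tensor:
        # forward_meta carries per-step state (e.g. chunked-MoE bookkeeping),
        # so it is threaded through to the backend's apply().
        return self.quant_method.apply(self, x, gate, forward_meta)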

Modifications

Add forward_meta to MoE models' forward functions.

Usage or Command

No change

Accuracy Tests

No change.

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests, or explain in this PR why they are not needed.
  • Provide accuracy results.
  • If this PR targets a release branch, make sure it has first been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings November 20, 2025 05:58
paddle-bot bot commented Nov 20, 2025

Thanks for your contribution!


Copilot AI left a comment


Pull Request Overview

This PR adds the forward_meta parameter to MoE (Mixture of Experts) models' forward functions so that MoE phase information is available during forward computation. The change is needed because forward_meta.moe_phase.phase is used in the fused MoE backend to decide between the prefill and decode execution paths.
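A rough sketch of that prefill/decode branch; everything except the forward_meta.moe_phase.phase lookup (the apply_prefill/apply_decode helper names in particular) is an assumed placeholder, not the real backend code:

class MoEMethodBase:
    def apply(self, layer, x, gate, forward_meta):
        # The phase string comes from forward_meta; helper names below are hypothetical.
        if forward_meta.moe_phase.phase == "prefill":
            return self.apply_prefill(layer, x, gate)  # prefill execution path
        return self.apply_decode(layer, x, gate)       # decode execution path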

Key Changes:

  • Updated core MoE layer to accept and propagate forward_meta parameter through the computation pipeline
  • Modified all MoE and MLP forward methods across multiple model architectures to include forward_meta parameter
  • Updated the speculative decoding module to pass forward_meta to empty_input_forward calls

Reviewed Changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 6 comments.

Show a summary per file
  • fastdeploy/model_executor/layers/moe/moe.py: Added forward_meta parameter to the FusedMoE forward method and propagated it to quant_method.apply calls
  • fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py: Added forward_meta parameter to MoEMethodBase.apply and uses it to check moe_phase
  • fastdeploy/model_executor/models/qwen3moe.py: Updated Qwen3MoeBlock and Qwen3MLP forward signatures to include forward_meta
  • fastdeploy/model_executor/models/qwen2.py: Updated Qwen2MLP forward signature to include forward_meta
  • fastdeploy/model_executor/models/gpt_oss.py: Updated GptOssMoe forward signature to include and propagate forward_meta
  • fastdeploy/model_executor/models/glm4_moe.py: Updated Glm4MoeMLP and Glm4Moe forward signatures to include forward_meta
  • fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py: Updated Ernie4_5_VLMoE and related classes to include forward_meta; updated empty_input_forward calls
  • fastdeploy/model_executor/models/ernie4_5_mtp.py: Updated empty_input_forward signature to accept a forward_meta parameter
  • fastdeploy/model_executor/models/ernie4_5_moe.py: Updated Ernie4_5_MLP and Ernie4_5_MoE forward signatures to include forward_meta; updated empty_input_forward calls
  • fastdeploy/model_executor/models/deepseek_v3.py: Updated DeepSeekV3MLP and DeepSeekV3MoE forward signatures to include forward_meta
  • fastdeploy/spec_decode/mtp.py: Updated the empty_input_forward call to pass the forward_meta parameter
Comments suppressed due to low confidence (1)

fastdeploy/model_executor/layers/moe/moe.py:615

            out = self.forward_split_allgather(x, gate, forward_meta)

- def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
+ def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta, vl_moe_meta: VLMoEMeta):
      if self.num_shared_experts > 0:
          shared_experts_out = self.shared_experts(hidden_states)

Copilot AI Nov 20, 2025


Missing forward_meta parameter in the shared_experts call. The shared_experts is an instance of Ernie4_5_VLMLP (which inherits from Ernie4_5_MLP) and now requires forward_meta as the second parameter. The call should be: shared_experts_out = self.shared_experts(hidden_states, forward_meta)

Suggested change
- shared_experts_out = self.shared_experts(hidden_states)
+ shared_experts_out = self.shared_experts(hidden_states, forward_meta)

      and token_num >= self.tp_size
  ):
-     out = self.forward_split_allgather(x, gate)
+     out = self.forward_split_allgather(x, gate, forward_meta)

Copilot AI Nov 20, 2025


The forward_split_allgather method is being called with forward_meta parameter (line 615), but its method signature at line 576 doesn't include this parameter. The signature should be updated to: def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta): and the parameter should be passed to quant_method.apply at line 591.
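A hedged sketch of the suggested signature (body heavily simplified; the real method also splits the tokens and all-gathers the results across tensor-parallel ranks):

def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
    # forward_meta is now accepted here and handed on to the backend.
    return self.quant_method.apply(self, x, gate, forward_meta)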

Comment on lines 356 to 357
forward_meta=forward_meta,
hidden_states=hidden_states,

Copilot AI Nov 20, 2025


Inconsistent parameter ordering: The method signature has forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta) (line 98), but the call uses forward_meta=forward_meta, hidden_states=hidden_states (lines 356-357). While this works with keyword arguments, it's inconsistent with the positional order. Consider using positional order: self.mlp(hidden_states, forward_meta) for better consistency with the method signature.

Suggested change
- forward_meta=forward_meta,
- hidden_states=hidden_states,
+ hidden_states,
+ forward_meta,

forward_meta=forward_meta,
)
if self.num_shared_experts > 0:
s_x = self.shared_experts(hidden_states)

Copilot AI Nov 20, 2025


Missing forward_meta parameter in the shared_experts call. The shared_experts is an instance of Ernie4_5_MLP which now requires forward_meta as the second parameter (line 98). The call should be: s_x = self.shared_experts(hidden_states, forward_meta)

Suggested change
- s_x = self.shared_experts(hidden_states)
+ s_x = self.shared_experts(hidden_states, forward_meta)


- def forward(self, x):
+ def forward(self, x, forward_meta):
      shared_experts_out = self.shared_experts(x)

Copilot AI Nov 20, 2025


Missing forward_meta parameter in the shared_experts call. The shared_experts is an instance of Glm4MoeMLP which now requires forward_meta as the second parameter (line 88). The call should be: shared_experts_out = self.shared_experts(x, forward_meta)

Suggested change
- shared_experts_out = self.shared_experts(x)
+ shared_experts_out = self.shared_experts(x, forward_meta)

Comment on lines 192 to 193
shared_experts_out = self.shared_experts(hidden_states)
moe_out = self.experts(hidden_states, self.gate)

Copilot AI Nov 20, 2025


Missing forward_meta parameter in both shared_experts and experts calls. Both methods now require forward_meta. The calls should be:

  • shared_experts_out = self.shared_experts(hidden_states, forward_meta)
  • moe_out = self.experts(hidden_states, self.gate, forward_meta)
Suggested change
- shared_experts_out = self.shared_experts(hidden_states)
- moe_out = self.experts(hidden_states, self.gate)
+ shared_experts_out = self.shared_experts(hidden_states, forward_meta)
+ moe_out = self.experts(hidden_states, self.gate, forward_meta)


codecov-commenter commented Nov 26, 2025

Codecov Report

❌ Patch coverage is 77.04918% with 14 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@209006e). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/model_executor/layers/moe/moe.py 77.77% 0 Missing and 2 partials ⚠️
...del_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py 80.00% 2 Missing ⚠️
fastdeploy/model_executor/models/gpt_oss.py 33.33% 2 Missing ⚠️
fastdeploy/worker/gpu_model_runner.py 71.42% 2 Missing ⚠️
fastdeploy/model_executor/models/deepseek_v3.py 83.33% 1 Missing ⚠️
fastdeploy/model_executor/models/ernie4_5_moe.py 83.33% 1 Missing ⚠️
fastdeploy/model_executor/models/ernie4_5_mtp.py 50.00% 1 Missing ⚠️
fastdeploy/model_executor/models/glm4_moe.py 83.33% 1 Missing ⚠️
fastdeploy/model_executor/models/qwen3moe.py 83.33% 1 Missing ⚠️
fastdeploy/spec_decode/mtp.py 50.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #5138   +/-   ##
==========================================
  Coverage           ?   59.50%           
==========================================
  Files              ?      325           
  Lines              ?    40273           
  Branches           ?     6097           
==========================================
  Hits               ?    23965           
  Misses             ?    14402           
  Partials           ?     1906           
Flag Coverage Δ
GPU 59.50% <77.04%> (?)

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

Wanglongzhi2001 force-pushed the add_forward_meta branch 2 times, most recently from c1fc576 to 43f6457 on December 3, 2025 13:19
gongshaotian previously approved these changes Dec 4, 2025
Comment on lines +146 to +148
# chunked MoE related
moe_num_chunk: int = 1
max_moe_num_chunk: int = 1
Collaborator


Suggestion: add an explanation for each field, to make long-term maintenance easier.

Collaborator Author


OK, I will fix this in the next PR.

Comment on lines +278 to +280
self.forward_meta.moe_num_chunk = (token_num + chunk_size - 1) // chunk_size
else:
self.fd_config.parallel_config.moe_num_chunk = 1
self.forward_meta.moe_num_chunk = 1
Collaborator


Should the ep phase change also go into the meta?

Collaborator


+1, this could be done in a separate PR.

Collaborator Author


Yes, this part also involves changes to the EP Runner and needs some design work; I will handle it in the next PR.

Collaborator


OK, please post the PR link here later.
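As an aside on the snippet quoted at the top of this thread: the chunk count is simply a ceiling division of token_num by chunk_size. A quick standalone check, with illustrative values:

token_num, chunk_size = 10_000, 4_096  # example values, not taken from the PR
moe_num_chunk = (token_num + chunk_size - 1) // chunk_size  # ceil(token_num / chunk_size)
assert moe_num_chunk == 3  # two full 4096-token chunks plus one partial chunk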

yuanlehome previously approved these changes Dec 4, 2025
freeliuzc previously approved these changes Dec 4, 2025
Wanglongzhi2001 merged commit 5cd17fd into PaddlePaddle:develop Dec 4, 2025
15 of 18 checks passed
fmiao2372 mentioned this pull request Dec 5, 2025
liyonghua0910 pushed a commit to liyonghua0910/FastDeploy that referenced this pull request Dec 5, 2025
…le#5138)

* [Models] Add forward_meta to moe models' forward function

* fix missing param

* fix

* fix

* fix forward_meta

* fix test and remove chunked MoE releated in config

* fix test

* fix

* fix
Wanglongzhi2001 deleted the add_forward_meta branch January 29, 2026 13:47