
Conversation

@Wanglongzhi2001 (Collaborator) commented Oct 24, 2025

Motivation

In some scenarios, the MoE input token count can be very large. To reduce the activation memory of MoE, this PR adds chunked MoE, which splits the MoE input into multiple parts.

Modifications

New feature: a chunked MoE forward pass that splits the MoE input into chunks.
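For context, a minimal, hedged sketch of the idea (the names chunked_moe_forward, moe_layer, and chunk_size are illustrative, not the exact code in this PR): the input is split along the token dimension, each chunk is run through the fused MoE, and the chunk outputs are concatenated back together.

import paddle

# Illustrative sketch only: split the MoE input into chunks of at most
# `chunk_size` tokens, apply the MoE to each chunk, and concatenate the results.
def chunked_moe_forward(moe_layer, x, gate, chunk_size=1024):
    token_num = x.shape[0]
    num_chunk = max(1, -(-token_num // chunk_size))  # ceiling division
    chunks = paddle.tensor_split(x, num_chunk, axis=0)
    outs = [moe_layer.quant_method.apply(moe_layer, chunk, gate) for chunk in chunks]
    return paddle.concat(outs, axis=0)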

Usage or Command

Just add two extra parameters, enable-chunked-moe and chunked-moe-size:

python -m fastdeploy.entrypoints.openai.multi_api_server \
       --ports "8280,8281,8282,8283,8284,8285,8286,8287" \
       --metrics-ports "8480,8481,8482,8488,8484,8485,8486,8487" \
       --num-servers 8 \
       --args --tensor-parallel-size 1 \
       --data-parallel-size 8 \
       --enable-expert-parallel \
       --enable-chunked-moe \
       --chunked-moe-size 1024 \
       --engine-worker-queue-port "8380,8381,8382,8383,8384,8385,8386,8387" \
       --max-model-len 16384 \
       --max-num-seqs 128 \
       --gpu-memory-utilization 0.9 \
       --model "$MODEL_PATH" \
       --num-gpu-blocks-override 12288 \
       --enable-mm-output \
       --prealloc-dec-block-slot-num-threshold 15 \
       --no-enable-prefix-caching \
       --quantization block_wise_fp8 \
       --ips $ip_list

Accuracy Tests

Does not affect model outputs.

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.


paddle-bot bot commented Oct 24, 2025

Thanks for your contribution!

@RichardWooSJTU (Collaborator) left a comment:

Is it possible to force using low_latency_dispatch if the chunk size is limited to 256?

if i == num_chunk - 1:
    out[i * chunk_size:, :] = self.quant_method.apply(self, x[i * chunk_size:, :], gate)
else:
    out[i * chunk_size:(i + 1) * chunk_size, :] = self.quant_method.apply(self, x[i * chunk_size:(i + 1) * chunk_size, :], gate)
Collaborator:

Consider using paddle.split before the loop instead of slicing at each step, which causes a lot of kernel launch overhead.

Collaborator:

Same for out, which can use paddle.concat after the loop.

assert out is not None, "FusedMOE forward got error result"
return out

def forward_chunked_moe(self, x, gate):
Collaborator:

Can CI cover this function right now? If not, a unit test needs to be added.

Collaborator Author:

OK, I will add a unit test later.

Comment on lines 617 to 620
if num_chunk == max_num_chunk:
    for i in range(num_chunk):
        out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
else:  # num_chunk < max_num_chunk
Collaborator:

These lines can be removed; the logic is redundant.

"""
out = self.quant_method.apply(self, x, gate)
out = None
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
@yuanlehome (Collaborator) commented Oct 24, 2025:

This could also be used without EP; can it be written more generically?

Copilot AI (Contributor) left a comment:

Pull request overview

This PR introduces chunked MoE (Mixture of Experts) support, enabling MoE layers to process inputs in configurable chunks to optimize memory usage and synchronization across distributed ranks.

Key changes:

  • Added enable_chunked_moe and chunked_moe_size configuration parameters
  • Implemented chunked MoE forward pass that splits inputs into chunks and synchronizes across ranks
  • Added distributed status collection to coordinate chunk sizes across ranks (see the sketch after this list)
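A hedged sketch of what that coordination could look like (illustrative only; the function name coordinate_num_chunks and the variable names are assumptions, not the PR's exact code): each rank derives its local chunk count from its token count, then all ranks take the maximum so every rank issues the same number of MoE steps.

import paddle
import paddle.distributed as dist

def coordinate_num_chunks(token_num: int, chunk_size: int) -> int:
    # Local number of chunks for this rank's tokens (ceiling division).
    local_num_chunk = max(1, -(-token_num // chunk_size))
    # Agree on the maximum across ranks so all ranks run the same number of steps.
    t = paddle.to_tensor([local_num_chunk], dtype="int32")
    dist.all_reduce(t, op=dist.ReduceOp.MAX)
    return int(t.item())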

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 12 comments.

Summary per file:
tests/layers/test_chunked_moe.py: New test file validating chunked MoE functionality with a multi-rank setup
fastdeploy/worker/worker_process.py: Added CLI arguments for chunked MoE configuration
fastdeploy/worker/model_runner_base.py: Introduced dataclasses for tracking distributed status
fastdeploy/worker/gpu_model_runner.py: Implemented distributed status collection and chunk size coordination
fastdeploy/model_executor/layers/moe/moe.py: Added chunked MoE forward pass implementation
fastdeploy/engine/engine.py: Updated worker service to pass chunked MoE configuration
fastdeploy/engine/async_llm.py: Updated async worker service to pass chunked MoE configuration
fastdeploy/engine/args_utils.py: Added CLI argument definitions for chunked MoE
fastdeploy/config.py: Added chunked MoE configuration fields to ParallelConfig


if i <= self.fd_config.parallel_config.moe_num_chunk - 1:
    out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
else:
    self.quant_method.apply(self, x, gate)
Copilot AI commented Nov 26, 2025:

The result of quant_method.apply() is discarded when i > moe_num_chunk - 1. This appears to be a synchronization mechanism but wastes computation. If synchronization is needed, consider using explicit barrier operations instead of dummy computation. If this is intentional for performance reasons, add a comment explaining why.

Suggested change:
- self.quant_method.apply(self, x, gate)
+ # Synchronization is required here to ensure all ranks are aligned.
+ # Replacing dummy computation with an explicit barrier for clarity and efficiency.
+ paddle.distributed.barrier()

Comment on lines 636 to 675
for i in range(self.fd_config.parallel_config.max_moe_num_chunk):
    out = self.quant_method.apply(self, x, gate)
Copilot AI commented Nov 26, 2025:

When token_num <= chunk_size, the same computation is repeated max_moe_num_chunk times, with only the last result being used. This appears to be for cross-rank synchronization but wastes significant compute resources. Consider using explicit communication primitives (e.g., paddle.distributed.barrier()) instead of redundant computation.

Suggested change:
- for i in range(self.fd_config.parallel_config.max_moe_num_chunk):
-     out = self.quant_method.apply(self, x, gate)
+ out = self.quant_method.apply(self, x, gate)
+ paddle.distributed.barrier()


@dataclass
class DistributedOut:
    if_only_decode: bool = None
Copilot AI commented Nov 26, 2025:

Using None as default for a boolean field is unconventional in dataclasses. Consider using Optional[bool] = None with the proper import, or use a default boolean value like False if a default state can be defined.

Suggested change:
- if_only_decode: bool = None
+ if_only_decode: Optional[bool] = None

"--chunked-moe-size",
type=int,
default=EngineArgs.chunked_moe_size,
help="chunked size of moe input.",
Copilot AI commented Nov 26, 2025:

Inconsistent capitalization in help text. The help text should start with a capital letter or follow the pattern of other help messages in the file. Change to 'Chunked size of moe input.' or 'Chunk size of MoE input.'

Suggested change:
- help="chunked size of moe input.",
+ help="Chunked size of MoE input.",


return out

def forward_normal(self, x, gate):
Copilot AI commented Nov 26, 2025:

Missing docstring for forward_normal method. Add documentation explaining this is the standard non-chunked MoE forward pass.

Suggested change:
  def forward_normal(self, x, gate):
+     """
+     Standard non-chunked MoE forward pass.
+
+     Args:
+         x (Tensor): Input tensor to the MoE layer.
+         gate (nn.Layer): Gating layer for expert selection.
+
+     Returns:
+         Tensor: Output tensor after applying the MoE experts.
+     """

    self.quant_method = MockQuantMethod()

def forward(self, x, gate):
    return self.quant_method.apply(x, gate)
Copilot AI commented Nov 26, 2025:

Mock implementation incorrectly passes only 2 arguments to quant_method.apply(), but the actual implementation (line 630, 632, 637 in moe.py) passes 3 arguments: self, x, gate. This makes the mock inconsistent with the real code and may not catch interface bugs. Update to return self.quant_method.apply(self, x, gate).

Suggested change:
- return self.quant_method.apply(x, gate)
+ return self.quant_method.apply(self, x, gate)

Comment on lines 85 to 86
def apply(self, layer, x, gate):
    return x
Copilot AI commented Nov 26, 2025:

The mock MockQuantMethod.apply ignores the layer and gate parameters and simply returns x. This doesn't validate that the actual chunked MoE logic correctly passes these parameters to the quant method. Consider adding assertions to verify the parameters are passed correctly.
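One way this could be done (a minimal sketch; the constructor arguments and the num_calls counter are hypothetical additions by this review, not part of the PR, and assume the test constructs the mock with the layer and gate objects it expects):

class MockQuantMethod:
    def __init__(self, expected_layer=None, expected_gate=None):
        self.expected_layer = expected_layer
        self.expected_gate = expected_gate
        self.num_calls = 0

    def apply(self, layer, x, gate):
        # Check that the chunked MoE path forwards the layer and gate it was given.
        if self.expected_layer is not None:
            assert layer is self.expected_layer, "apply() received an unexpected layer"
        if self.expected_gate is not None:
            assert gate is self.expected_gate, "apply() received an unexpected gate"
        self.num_calls += 1
        return x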


codecov-commenter commented Nov 28, 2025

Codecov Report

❌ Patch coverage is 97.22222% with 2 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@051b82b). Learn more about missing BASE report.

Files with missing lines (Patch %, Lines):
fastdeploy/model_executor/layers/moe/moe.py: 90.90% (1 Missing, 1 Partial) ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #4575   +/-   ##
==========================================
  Coverage           ?   59.74%           
==========================================
  Files              ?      324           
  Lines              ?    39669           
  Branches           ?     5965           
==========================================
  Hits               ?    23701           
  Misses             ?    14087           
  Partials           ?     1881           
Flag GPU: Coverage Δ 59.74% <97.22%> (?)

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@RichardWooSJTU (Collaborator) left a comment:

LGTM


if_only_decode = dist_status.if_only_decode
if self.fd_config.parallel_config.enable_chunked_moe:
    self.fd_config.parallel_config.max_moe_num_chunk = dist_status.max_moe_num_chunk
Collaborator:

Why is forward_meta no longer passed to the MoE layer?


Comment on lines +543 to +546
self.enable_chunked_moe = False
self.chunked_moe_size = 256
self.max_moe_num_chunk = 1
self.moe_num_chunk = 1
Collaborator:

There are quite a few chunked MoE related parameters now; with more than two, consider packing them into a dict.
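For illustration, a hedged sketch of what the suggested grouping could look like (the class name ChunkedMoEConfig is hypothetical; the field names mirror the attributes shown in the diff context above):

from dataclasses import dataclass

@dataclass
class ChunkedMoEConfig:
    enable_chunked_moe: bool = False
    chunked_moe_size: int = 256
    max_moe_num_chunk: int = 1
    moe_num_chunk: int = 1

# e.g. ParallelConfig would then hold a single attribute:
# self.chunked_moe = ChunkedMoEConfig()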

Collaborator Author:

OK, I'll change that together with the upcoming forward_meta PR; this one needs to go out for testing today.

Collaborator:

@gongshaotian please approve~

x_split_list = paddle.tensor_split(x, self.fd_config.parallel_config.moe_num_chunk, axis=0)
out_split_list = [None] * self.fd_config.parallel_config.moe_num_chunk

for i in range(self.fd_config.parallel_config.max_moe_num_chunk):
Collaborator:

Is max_moe_num_chunk dynamic? I have a feeling this can't be captured by cudaGraph.

@Wanglongzhi2001 (Collaborator Author) commented Dec 1, 2025:

> Is max_moe_num_chunk dynamic? I have a feeling this can't be captured by cudaGraph.

Currently this is only used for image generation, and the chunked size is set to 1024 or larger, so when max_moe_num_chunk changes, cudagraph is not entered.

Collaborator Author:

> Is max_moe_num_chunk dynamic? I have a feeling this can't be captured by cudaGraph.

My understanding is that the scenarios that need this feature are also the ones with a very large number of tokens, which are already incompatible with cudagraph. I can add an assert in the next PR so that this and cudagraph cannot both be enabled.
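A hedged sketch of what that follow-up check could look like (the attribute names parallel_config.enable_chunked_moe and graph_opt_config.use_cudagraph are assumptions based on this PR and FastDeploy's config layout; the real assert will land in a later PR):

# Hypothetical config validation; attribute names are assumed, not confirmed.
if fd_config.parallel_config.enable_chunked_moe:
    assert not fd_config.graph_opt_config.use_cudagraph, (
        "Chunked MoE is not compatible with CUDA Graph capture; disable one of them."
    )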

@gongshaotian (Collaborator) left a comment:

LGTM

@Wanglongzhi2001 merged commit add524d into PaddlePaddle:develop on Dec 1, 2025
15 of 19 checks passed
@Wanglongzhi2001 deleted the chunked_moe branch on January 29, 2026 at 13:47
