Conversation

@zhupengyang
Collaborator

@zhupengyang zhupengyang commented Nov 5, 2025

Motivation

  • Add the environment variable FD_EP_TP_STRATEGY to choose the EP+TP split strategy: "all_reduce" or "all_to_all" (see the sketch below).
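
As a rough illustration of how the two values differ (a minimal sketch only; the real wiring lives inside FastDeploy's attention and linear layers, and the helper name below is hypothetical):

import os

# Hypothetical helper; FastDeploy reads this variable through its envs module.
def select_ep_tp_strategy() -> str:
    strategy = os.getenv("FD_EP_TP_STRATEGY", "all_reduce")
    if strategy not in ("all_reduce", "all_to_all"):
        raise ValueError(f"Unsupported FD_EP_TP_STRATEGY: {strategy}")
    return strategy

# "all_reduce": qkv_linear -> attn -> out_linear -> allreduce
# "all_to_all": allgather -> qkv_linear -> attn -> all2all -> out_linear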

Modifications

Usage or Command

export FD_EP_TP_STRATEGY="all_to_all"
python -m fastdeploy.entrypoints.openai.api_server \
  --model <model_path> \
  --port 8188 \
  --tensor-parallel-size 8 \
  --enable-expert-parallel \
  --data-parallel-size 8 \
  --max-model-len 32768 \
  --max-num-seqs 64 \
  --quantization "wint4" \
  --engine-worker-queue-port "8023,8033,8043,8053,8063,8073,8083,8093" \
  --gpu-memory-utilization 0.9 \
  --load-choices "default"

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. If no unit tests are added, explain the reason in this PR.
  • Provide accuracy results.
  • If the PR targets the release branch, make sure it has already been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot bot commented Nov 5, 2025

Thanks for your contribution!

# ep+tp split strategy
# 0: qkv_linear + attn + out_linear + allreduce
# 1: allgather + qkv_linear + attn + all2all + out_linear
self.ep_tp_split_mode = int(os.getenv("EP_TP_SPLIT_MODE", 0))
Collaborator

Put the new environment variable in envs.py, name it with the FD_ prefix, and add a comment for it.

Collaborator

I'd also suggest making it a string; the values could be all_reduce or all_to_all.
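
A minimal sketch of what the reviewers are asking for, assuming envs.py follows the usual pattern of a dict of os.getenv lambdas (the actual file layout in FastDeploy may differ):

import os

environment_variables = {
    # EP+TP split strategy:
    #   "all_reduce": qkv_linear + attn + out_linear + allreduce
    #   "all_to_all": allgather + qkv_linear + attn + all2all + out_linear
    "FD_EP_TP_STRATEGY": lambda: os.getenv("FD_EP_TP_STRATEGY", "all_reduce"),
}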

Collaborator Author

done

    block_tables=share_inputs["block_tables"],
    caches=share_inputs["caches"],
)
xpu_forward_meta.token_num = token_num
Collaborator

Adding a field to forward_meta requires adding it to the data class. For this particular line I'd rather not add one: xpu_forward_meta.ids_remove_padding.shape[0] already gives you token_num.
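
A minimal sketch of the reviewer's point (the helper name is illustrative; in the PR the expression is simply used inline):

def get_token_num(forward_meta):
    # Derive the token count from the un-padded ids instead of storing a new field.
    return forward_meta.ids_remove_padding.shape[0]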

Collaborator Author

done

Comment on lines 271 to 278
no_tp_action_keys = copy.deepcopy(num_local_ffn_keys)
if fd_config.parallel_config.ep_tp_split_mode == 1:
    for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
        k = f"ernie.layers.{i}.self_attn.o_proj.weight"
        if k in weight_list:
            no_tp_action_keys.append(k)
tp_actions = cls._get_tensor_parallel_mappings(fd_config.model_config.pretrained_config)
new_actions = {k: v for k, v in tp_actions.items() if k not in no_tp_action_keys}
Collaborator

Leave a TODO here: the V1 loader path never reaches this code, so it still needs to be adapted for the V1 loader.

Collaborator Author

This will be handled together when the V1 loader is adapted.

Comment on lines 238 to 248
out = norm_out[0].astype(x_dtype)
residual_out = norm_out[1].astype(residual_input_dtype) if residual_input is not None else None

if self.split_x:
    residual_out = self.split(residual_out)
if self.allgather_out:
    out = self.allgather(out, forward_meta.token_num)

if residual_input is None:
    return out
else:
    return out, residual_out
Collaborator

Could this be moved into the linear layer instead? It looks odd here; it isn't reasonable for the norm layer to be aware of tp/ep.

Collaborator Author

residual_out needs to be split here, and that can only be done inside the norm layer.
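
For context, a minimal sketch of what such a token-dimension split could look like (names and signature are illustrative, not the actual FastDeploy implementation):

import paddle

def split_tokens(x: paddle.Tensor, tp_rank: int, nranks: int) -> paddle.Tensor:
    # Keep only this rank's contiguous slice along the token dimension.
    token_num_per_rank = x.shape[0] // nranks
    start = tp_rank * token_num_per_rank
    return x[start:start + token_num_per_rank, :]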

Comment on lines 832 to 835
if self.split_token:
    self.num_heads = fd_config.model_config.num_attention_heads
else:
    self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
Collaborator

Can this num_heads part be removed? RowParallelLinear shouldn't need the variable; and if self.num_heads * self.head_dim is needed later, wouldn't self.hidden_size do?

Collaborator Author

done

Comment on lines +876 to +875
if token_num_pad > token_num:
    x_new = paddle.zeros([token_num_pad, x.shape[1]], x.dtype)
    x_new[:token_num, :] = x
    x = x_new
Collaborator

This shouldn't be guarded by an if; otherwise CUDA graphs can't capture it.
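
A minimal sketch of the unconditional-padding variant the reviewer suggests (helper name and signature are illustrative):

import paddle

def pad_tokens(x: paddle.Tensor, token_num_pad: int) -> paddle.Tensor:
    # Always allocate the padded buffer so the shape never depends on a runtime branch,
    # which keeps the op sequence CUDA-graph capturable.
    token_num = x.shape[0]
    x_new = paddle.zeros([token_num_pad, x.shape[1]], x.dtype)
    x_new[:token_num, :] = x
    return x_new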

reduce_results: bool = True,
skip_quant: bool = False,
weight_dtype="",
layer_id: int = -1,
Collaborator

For this newly added parameter, I'd suggest passing it explicitly in every model.py; defaulting to -1 isn't reasonable.

Collaborator Author

@zhupengyang zhupengyang Nov 6, 2025

The current approach requires a layer to determine, from inside the layer itself, where it sits in the network.

Every out_linear uses RowParallelLinear, and its position is identified by passing in layer_id, which is a bit of a trick.

For example, mlp and shared_expert also use RowParallelLinear; if they all pass layer_id, extra information is still needed to tell whether the current layer is the out_linear.

We may need to look into a more precise way of describing where a layer sits in the model, e.g. passing a name string so the layer can identify its position.
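
A minimal sketch of the name-string idea floated here (the prefix convention follows the weight keys shown earlier in this thread; the helper itself is hypothetical):

def is_attn_out_linear(prefix: str) -> bool:
    # e.g. "ernie.layers.3.self_attn.o_proj" is the attention output projection.
    return prefix.endswith(".self_attn.o_proj")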

token_num_per_rank = out.shape[0]
multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, out.shape[1]], dtype=out.dtype)
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = multi_outs if token_num is None else multi_outs[:token_num, :]
Collaborator

This also prevents CUDA graph capture; just always take multi_outs[:token_num, :]?
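
A minimal sketch of the shape-stable variant being suggested (it mirrors the snippet above; the standalone signature is illustrative, and the all_gather call pattern is carried over from the PR rather than verified against a specific Paddle version):

import paddle
import paddle.distributed

def allgather_tokens(out, token_num, tp_size, tp_group):
    # Gather every rank's token shard, then always slice back to token_num rows
    # so no data-dependent branch remains for CUDA graph capture.
    token_num_per_rank = out.shape[0]
    multi_outs = paddle.zeros([token_num_per_rank * tp_size, out.shape[1]], dtype=out.dtype)
    paddle.distributed.all_gather(multi_outs, out, tp_group)
    return multi_outs[:token_num, :]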

Collaborator Author

done

Collaborator

@hong19860320 hong19860320 left a comment

LGTM

@yuanlehome
Collaborator

Please add an example launch script to the PR description.

Comment on lines +240 to +243
if self.split_x:
    residual_out = self.split(residual_out)
if self.allgather_out:
    out = self.allgather(out, forward_meta.ids_remove_padding.shape[0])
Collaborator

Do we have unit tests that cover these two branches at the moment?

Collaborator Author

There are XPU unit tests for this. The GPU path currently doesn't reach here; the GPU folks will need to adapt and verify it themselves.

@EmmonsCurse EmmonsCurse merged commit b54eb7a into PaddlePaddle:develop Nov 6, 2025
26 of 32 checks passed