Skip to content

Conversation

@Lmywl
Copy link
Contributor

@Lmywl Lmywl commented Aug 11, 2025

背景:cudagraph 捕获过程中的张量地址管理
目的:将attention模块的输出前置,便于cudagraph捕获时的张量地址处理

@paddle-bot
Copy link

paddle-bot bot commented Aug 11, 2025

Thanks for your contribution!

@gongshaotian
Copy link
Collaborator

麻烦再丰富一下PR描述,说明一下改造的背景、目标

YuanRisheng
YuanRisheng previously approved these changes Aug 21, 2025
Copilot AI review requested due to automatic review settings August 26, 2025 08:02
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

gongshaotian
gongshaotian previously approved these changes Aug 26, 2025
Copy link
Collaborator

@DrRyanHuang DrRyanHuang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

除了下面俩直接能改的,还有 append_attention 申明为自定义算子的这部分,Outputs 也要改一下
因为咱 append_attention 不是不需要输出 qkv_out 了嘛,所以删掉它

PD_BUILD_STATIC_OP(append_attention)
    .Inputs({"qkv",
    ......
    .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})  # <--- 这一行
    .SetInplaceMap({{"key_cache", "key_cache_out"},

改成

    .Outputs({"fmha_out", "key_cache_out", "value_cache_out"})

PS: 自定义算子的注册出问题总是直接抛出这种异常,后续主框架也要添加更多上下文信息

terminate called after throwing an instance of 'std:bad_array_new_length'
	what(): std: : bad_array_new_length

cc @zyfncg @SigureMo

const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
paddle::Tensor &res,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
paddle::Tensor &res,
paddle::Tensor &fmha_out,

Comment on lines 1126 to 1141
.Attrs({"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
"max_input_length: int",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"max_partition_size: int",
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
"rms_norm_eps: float"})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里把 rms_norm_eps 的顺序往前移动一下

Suggested change
.Attrs({"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
"max_input_length: int",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"max_partition_size: int",
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
"rms_norm_eps: float"})
.Attrs({"rms_norm_eps: float",
"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
"max_input_length: int",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"max_partition_size: int",
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
})

@Lmywl Lmywl dismissed stale reviews from gongshaotian and YuanRisheng via 440a44d August 28, 2025 06:21
Copy link
Collaborator

@gongshaotian gongshaotian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@gongshaotian gongshaotian merged commit e93d4cf into PaddlePaddle:develop Aug 28, 2025
13 of 16 checks passed
Jiang-Jia-Jun pushed a commit that referenced this pull request Oct 17, 2025
* rm inplace info && to(gpu)

* update append_attention

* unpin paddle version

* add full_cuda_graph=False

* add blank line

---------

Co-authored-by: SigureMo <sigure.qaq@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants